diff options
-rw-r--r-- | conf/java-opts | 1 | ||||
-rw-r--r-- | conf/log4j.properties | 8 | ||||
-rwxr-xr-x | conf/spark-env.sh | 13 | ||||
-rw-r--r-- | src/scala/spark/Broadcast.scala | 808 | ||||
-rw-r--r-- | src/scala/spark/ChainedBroadcast.scala | 863 | ||||
-rw-r--r-- | src/scala/spark/DfsBroadcast.scala | 127 | ||||
-rw-r--r-- | src/scala/spark/SparkContext.scala | 22 | ||||
-rw-r--r-- | src/scala/spark/repl/ClassServer.scala | 77 |
8 files changed, 1161 insertions, 758 deletions
diff --git a/conf/java-opts b/conf/java-opts new file mode 100644 index 0000000000..20a2ade45c --- /dev/null +++ b/conf/java-opts @@ -0,0 +1 @@ +-Dspark.broadcast.MasterHostAddress=127.0.0.1 -Dspark.broadcast.MasterTrackerPort=22222 -Dspark.broadcast.BlockSize=256 -Dspark.broadcast.MaxRetryCount=2 -Dspark.broadcast.TrackerSocketTimeout=50000 -Dspark.broadcast.ServerSocketTimeout=10000 diff --git a/conf/log4j.properties b/conf/log4j.properties new file mode 100644 index 0000000000..33774b463d --- /dev/null +++ b/conf/log4j.properties @@ -0,0 +1,8 @@ +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN diff --git a/conf/spark-env.sh b/conf/spark-env.sh new file mode 100755 index 0000000000..77f9cb69b9 --- /dev/null +++ b/conf/spark-env.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +# Set Spark environment variables for your site in this file. Some useful +# variables to set are: +# - MESOS_HOME, to point to your Mesos installation +# - SCALA_HOME, to point to your Scala installation +# - SPARK_CLASSPATH, to add elements to Spark's classpath +# - SPARK_JAVA_OPTS, to add JVM options +# - SPARK_MEM, to change the amount of memory used per node (this should +# be in the same format as the JVM's -Xmx option, e.g. 300m or 1g). +# - SPARK_LIBRARY_PATH, to add extra search paths for native libraries. + +MESOS_HOME=/home/mosharaf/Work/mesos diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala index 5089dca82e..afff500bb0 100644 --- a/src/scala/spark/Broadcast.scala +++ b/src/scala/spark/Broadcast.scala @@ -1,23 +1,10 @@ package spark -import java.io._ -import java.net._ -import java.util.{UUID, PriorityQueue, Comparator} - -import java.util.concurrent.{Executors, ExecutorService} - -import scala.actors.Actor -import scala.actors.Actor._ - -import scala.collection.mutable.Map - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} - -import spark.compress.lzf.{LZFInputStream, LZFOutputStream} +import java.util.UUID +import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory} @serializable -trait BroadcastRecipe { +trait Broadcast { val uuid = UUID.randomUUID // We cannot have an abstract readObject here due to some weird issues with @@ -27,173 +14,80 @@ trait BroadcastRecipe { override def toString = "spark.Broadcast(" + uuid + ")" } -// TODO: Right, now no parallelization between multiple broadcasts -@serializable -class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) -extends BroadcastRecipe with Logging { - - def value = value_ +private object Broadcast +extends Logging { + private var initialized = false - BroadcastCS.synchronized { BroadcastCS.values.put (uuid, value_) } - - if (!local) { sendBroadcast } - - def sendBroadcast () { - // Create a variableInfo object and store it in valueInfos - var variableInfo = blockifyObject (value_, BroadcastCS.blockSize) - // TODO: Even though this part is not in use now, there is problem in the - // following statement. Shouldn't use constant port and hostAddress anymore? - // val masterSource = - // new SourceInfo (BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort, - // variableInfo.totalBlocks, variableInfo.totalBytes, 0) - // variableInfo.pqOfSources.add (masterSource) - - BroadcastCS.synchronized { - // BroadcastCS.valueInfos.put (uuid, variableInfo) - - // TODO: Not using variableInfo in current implementation. Manually - // setting all the variables inside BroadcastCS object - - BroadcastCS.initializeVariable (variableInfo) + // Called by SparkContext or Executor before using Broadcast + // Calls all other initializers here + def initialize (isMaster: Boolean): Unit = { + synchronized { + if (!initialized) { + // Initialization for DfsBroadcast + DfsBroadcast.initialize + // Initialization for ChainedStreamingBroadcast + ChainedBroadcast.initialize (isMaster) + + initialized = true + } } - - // Now store a persistent copy in HDFS, just in case - val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) - out.writeObject (value_) - out.close } - // Called by Java when deserializing an object - private def readObject (in: ObjectInputStream) { - in.defaultReadObject - BroadcastCS.synchronized { - val cachedVal = BroadcastCS.values.get (uuid) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - // Only a single worker (the first one) in the same node can ever be - // here. The rest will always get the value ready. - val start = System.nanoTime - - val retByteArray = BroadcastCS.receiveBroadcast (uuid) - // If does not succeed, then get from HDFS copy - if (retByteArray != null) { - value_ = byteArrayToObject[T] (retByteArray) - BroadcastCS.values.put (uuid, value_) - // val variableInfo = blockifyObject (value_, BroadcastCS.blockSize) - // BroadcastCS.valueInfos.put (uuid, variableInfo) - } else { - val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - BroadcastCH.values.put(uuid, value_) - fileIn.close - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread (r) + t.setDaemon (true) + return t } - } + } } - private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream (baos) - oos.writeObject (obj) - oos.close - baos.close - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream (byteArray) - - var blockNum = (byteArray.length / blockSize) - if (byteArray.length % blockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock] (blockNum) - var blockID = 0 - - // TODO: What happens in byteArray.length == 0 => blockNum == 0 - for (i <- 0 until (byteArray.length, blockSize)) { - val thisBlockSize = Math.min (blockSize, byteArray.length - i) - var tempByteArray = new Array[Byte] (thisBlockSize) - val hasRead = bais.read (tempByteArray, 0, thisBlockSize) - - retVal (blockID) = new BroadcastBlock (blockID, tempByteArray) - blockID += 1 - } - bais.close - - var variableInfo = VariableInfo (retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } + // Wrapper over newCachedThreadPool + def newDaemonCachedThreadPool: ThreadPoolExecutor = { + var threadPool = + Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] - private def byteArrayToObject[A] (bytes: Array[Byte]): A = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) - val retVal = in.readObject.asInstanceOf[A] - in.close - return retVal + threadPool.setThreadFactory (newDaemonThreadFactory) + + return threadPool } - private def getByteArrayOutputStream (obj: T): ByteArrayOutputStream = { - val bOut = new ByteArrayOutputStream - val out = new ObjectOutputStream (bOut) - out.writeObject (obj) - out.close - bOut.close - return bOut - } -} - -@serializable -class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean) -extends BroadcastRecipe with Logging { + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor] - def value = value_ - - BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) } - - if (!local) { sendBroadcast } - - def sendBroadcast () { - val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) - out.writeObject (value_) - out.close - } - - // Called by Java when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject - BroadcastCH.synchronized { - val cachedVal = BroadcastCH.values.get(uuid) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - val start = System.nanoTime - - val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - BroadcastCH.values.put(uuid, value_) - fileIn.close - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") - } - } - } + threadPool.setThreadFactory (newDaemonThreadFactory) + + return threadPool + } } @serializable case class SourceInfo (val hostAddress: String, val listenPort: Int, val totalBlocks: Int, val totalBytes: Int, val replicaID: Int) -extends Comparable[SourceInfo]{ +extends Comparable [SourceInfo] with Logging { var currentLeechers = 0 var receptionFailed = false + var hasBlocks = 0 + + // Ascending sort based on leecher count def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) } +object SourceInfo { + // Constants for special values of listenPort + val TxNotStartedRetry = -1 + val TxOverGoToHDFS = 0 + // Other constants + val StopBroadcast = -2 + val UnusedParam = 0 +} + @serializable case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { } @@ -202,598 +96,4 @@ case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock], val totalBlocks: Int, val totalBytes: Int) { @transient var hasBlocks = 0 - - val listenPortLock = new AnyRef - val totalBlocksLock = new AnyRef - val hasBlocksLock = new AnyRef - - @transient var pqOfSources = new PriorityQueue[SourceInfo] -} - -private object Broadcast { - private var initialized = false - - // Will be called by SparkContext or Executor before using Broadcast - // Calls all other initializers here - def initialize (isMaster: Boolean) { - synchronized { - if (!initialized) { - // Initialization for CentralizedHDFSBroadcast - BroadcastCH.initialize - // Initialization for ChainedStreamingBroadcast - // BroadcastCS.initialize (isMaster) - - initialized = true - } - } - } -} - -private object BroadcastCS extends Logging { - val values = Cache.newKeySpace() - - // private var valueToPort = Map[UUID, Int] () - - private var initialized = false - private var isMaster_ = false - - private var masterHostAddress_ = "127.0.0.1" - private var masterListenPort_ : Int = 11111 - private var blockSize_ : Int = 512 * 1024 - private var maxRetryCount_ : Int = 2 - private var serverSocketTimout_ : Int = 50000 - private var dualMode_ : Boolean = false - - private val hostAddress = InetAddress.getLocalHost.getHostAddress - private var listenPort = -1 - - var arrayOfBlocks: Array[BroadcastBlock] = null - var totalBytes = -1 - var totalBlocks = -1 - var hasBlocks = 0 - - val listenPortLock = new Object - val totalBlocksLock = new Object - val hasBlocksLock = new Object - - var pqOfSources = new PriorityQueue[SourceInfo] - - private var serveMR: ServeMultipleRequests = null - private var guideMR: GuideMultipleRequests = null - - def initialize (isMaster__ : Boolean) { - synchronized { - if (!initialized) { - masterHostAddress_ = - System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1") - masterListenPort_ = - System.getProperty ("spark.broadcast.masterListenPort", "11111").toInt - blockSize_ = - System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024 - maxRetryCount_ = - System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt - serverSocketTimout_ = - System.getProperty ("spark.broadcast.serverSocketTimout", "50000").toInt - dualMode_ = - System.getProperty ("spark.broadcast.dualMode", "false").toBoolean - - isMaster_ = isMaster__ - - if (isMaster) { - guideMR = new GuideMultipleRequests - guideMR.setDaemon (true) - guideMR.start - logInfo("GuideMultipleRequests started") - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon (true) - serveMR.start - logInfo("ServeMultipleRequests started") - - logInfo("BroadcastCS object has been initialized") - - initialized = true - } - } - } - - // TODO: This should change in future implementation. - // Called from the Master constructor to setup states for this particular that - // is being broadcasted - def initializeVariable (variableInfo: VariableInfo) { - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - // listenPort should already be valid - assert (listenPort != -1) - - pqOfSources = new PriorityQueue[SourceInfo] - val masterSource_0 = - new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) - BroadcastCS.pqOfSources.add (masterSource_0) - // Add one more time to have two replicas of any seeds in the PQ - if (BroadcastCS.dualMode) { - val masterSource_1 = - new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1) - BroadcastCS.pqOfSources.add (masterSource_1) - } - } - - def masterHostAddress = masterHostAddress_ - def masterListenPort = masterListenPort_ - def blockSize = blockSize_ - def maxRetryCount = maxRetryCount_ - def serverSocketTimout = serverSocketTimout_ - def dualMode = dualMode_ - - def isMaster = isMaster_ - - def receiveBroadcast (variableUUID: UUID): Array[Byte] = { - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - // NO need to wait; ServeMultipleRequests is created much further ahead - while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait - } - } - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = BroadcastCS.maxRetryCount - var retByteArray: Array[Byte] = null - do { - // Connect to Master and send this worker's Information - val clientSocketToMaster = - new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort) - logInfo("Connected to Master's guiding object") - // TODO: Guiding object connection is reusable - val oisMaster = - new ObjectInputStream (clientSocketToMaster.getInputStream) - val oosMaster = - new ObjectOutputStream (clientSocketToMaster.getOutputStream) - - oosMaster.writeObject(new SourceInfo (hostAddress, listenPort, -1, -1, 0)) - oosMaster.flush - - // Receive source information from Master - var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks) - totalBlocksLock.synchronized { - totalBlocksLock.notifyAll - } - totalBytes = sourceInfo.totalBytes - logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) - - retByteArray = receiveSingleTransmission (sourceInfo) - - logInfo("I got this from receiveSingleTransmission: " + retByteArray) - - // TODO: Update sourceInfo to add error notifactions for Master - if (retByteArray == null) { sourceInfo.receptionFailed = true } - - // TODO: Supposed to update values here, but we don't support advanced - // statistics right now. Master can handle leecherCount by itself. - - // Send back statistics to the Master - oosMaster.writeObject (sourceInfo) - - oisMaster.close - oosMaster.close - clientSocketToMaster.close - - retriesLeft -= 1 - } while (retriesLeft > 0 && retByteArray == null) - - return retByteArray - } - - // Tries to receive broadcast from the Master and returns Boolean status. - // This might be called multiple times to retry a defined number of times. - private def receiveSingleTransmission(sourceInfo: SourceInfo): Array[Byte] = { - var clientSocketToSource: Socket = null - var oisSource: ObjectInputStream = null - var oosSource: ObjectOutputStream = null - - var retByteArray:Array[Byte] = null - - try { - // Connect to the source to get the object itself - clientSocketToSource = - new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = - new ObjectOutputStream (clientSocketToSource.getOutputStream) - oisSource = - new ObjectInputStream (clientSocketToSource.getInputStream) - - logInfo("Inside receiveSingleTransmission") - logInfo("totalBlocks: " + totalBlocks + " " + "hasBlocks: " + hasBlocks) - retByteArray = new Array[Byte] (totalBytes) - for (i <- 0 until totalBlocks) { - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - System.arraycopy (bcBlock.byteArray, 0, retByteArray, - i * BroadcastCS.blockSize, bcBlock.byteArray.length) - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - hasBlocksLock.synchronized { - hasBlocksLock.notifyAll - } - logInfo("Received block: " + i + " " + bcBlock) - } - assert (hasBlocks == totalBlocks) - logInfo("After the receive loop") - } catch { - case e: Exception => { - retByteArray = null - logInfo("receiveSingleTransmission had a " + e) - } - } finally { - if (oisSource != null) { oisSource.close } - if (oosSource != null) { oosSource.close } - if (clientSocketToSource != null) { clientSocketToSource.close } - } - - return retByteArray - } - -// class TrackMultipleValues extends Thread with Logging { -// override def run = { -// var threadPool = Executors.newCachedThreadPool -// var serverSocket: ServerSocket = null -// -// serverSocket = new ServerSocket (BroadcastCS.masterListenPort) -// logInfo("TrackMultipleVariables" + serverSocket + " " + listenPort) -// -// var keepAccepting = true -// try { -// while (keepAccepting) { -// var clientSocket: Socket = null -// try { -// serverSocket.setSoTimeout (serverSocketTimout) -// clientSocket = serverSocket.accept -// } catch { -// case e: Exception => { -// logInfo("TrackMultipleValues Timeout. Stopping listening...") -// keepAccepting = false -// } -// } -// logInfo("TrackMultipleValues:Got new request:" + clientSocket) -// if (clientSocket != null) { -// try { -// threadPool.execute (new Runnable { -// def run = { -// val oos = new ObjectOutputStream (clientSocket.getOutputStream) -// val ois = new ObjectInputStream (clientSocket.getInputStream) -// try { -// val variableUUID = ois.readObject.asInstanceOf[UUID] -// var contactPort = 0 -// // TODO: Add logic and data structures to find out UUID->port -// // mapping. 0 = missed the broadcast, read from HDFS; <0 = -// // Haven't started yet, wait & retry; >0 = Read from this port -// oos.writeObject (contactPort) -// } catch { -// case e: Exception => { } -// } finally { -// ois.close -// oos.close -// clientSocket.close -// } -// } -// }) -// } catch { -// // In failure, close the socket here; else, the thread will close it -// case ioe: IOException => clientSocket.close -// } -// } -// } -// } finally { -// serverSocket.close -// } -// } -// } -// -// class TrackSingleValue { -// -// } - -// public static ExecutorService newCachedThreadPool() { -// return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 60L, TimeUnit.SECONDS, -// new SynchronousQueue<Runnable>()); -// } - - - class GuideMultipleRequests extends Thread with Logging { - override def run = { - var threadPool = Executors.newCachedThreadPool - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket (BroadcastCS.masterListenPort) - // listenPort = BroadcastCS.masterListenPort - logInfo("GuideMultipleRequests" + serverSocket + " " + listenPort) - - var keepAccepting = true - try { - while (keepAccepting) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout (serverSocketTimout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("GuideMultipleRequests Timeout. Stopping listening...") - keepAccepting = false - } - } - if (clientSocket != null) { - logInfo("Guide:Accepted new client connection:" + clientSocket) - try { - threadPool.execute (new GuideSingleRequest (clientSocket)) - } catch { - // In failure, close the socket here; else, the thread will close it - case ioe: IOException => clientSocket.close - } - } - } - } finally { - serverSocket.close - } - } - - class GuideSingleRequest (val clientSocket: Socket) - extends Runnable with Logging { - private val oos = new ObjectOutputStream (clientSocket.getOutputStream) - private val ois = new ObjectInputStream (clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - def run = { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. ReplicaID is 0 and other fields are invalid (-1) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource (sourceInfo) - logInfo("Sending selectedSourceInfo:" + selectedSourceInfo) - oos.writeObject (selectedSourceInfo) - oos.flush - - // Add this new (if it can finish) source to the PQ of sources - thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes, 0) - logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo) - pqOfSources.synchronized { - pqOfSources.add (thisWorkerInfo) - } - - // Wait till the whole transfer is done. Then receive and update source - // statistics in pqOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - pqOfSources.synchronized { - // This should work since SourceInfo is a case class - assert (pqOfSources.contains (selectedSourceInfo)) - - // Remove first - pqOfSources.remove (selectedSourceInfo) - // TODO: Removing a source based on just one failure notification! - // Update leecher count and put it back in IF reception succeeded - if (!sourceInfo.receptionFailed) { - selectedSourceInfo.currentLeechers -= 1 - pqOfSources.add (selectedSourceInfo) - - // No need to find and update thisWorkerInfo, but add its replica - if (BroadcastCS.dualMode) { - pqOfSources.add (new SourceInfo (thisWorkerInfo.hostAddress, - thisWorkerInfo.listenPort, totalBlocks, totalBytes, 1)) - } - } - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - case e: Exception => { - // Assuming that exception caused due to receiver worker failure - // Remove failed worker from pqOfSources and update leecherCount of - // corresponding source worker - pqOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - pqOfSources.remove (selectedSourceInfo) - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - pqOfSources.add (selectedSourceInfo) - } - - // Remove thisWorkerInfo - if (pqOfSources != null) { pqOfSources.remove (thisWorkerInfo) } - } - } - } finally { - ois.close - oos.close - clientSocket.close - } - } - - // TODO: If a worker fails to get the broadcasted variable from a source and - // comes back to Master, this function might choose the worker itself as a - // source tp create a dependency cycle (this worker was put into pqOfSources - // as a streming source when it first arrived). The length of this cycle can - // be arbitrarily long. - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - // Select one with the lowest number of leechers - pqOfSources.synchronized { - // take is a blocking call removing the element from PQ - var selectedSource = pqOfSources.poll - assert (selectedSource != null) - // Update leecher count - selectedSource.currentLeechers += 1 - // Add it back and then return - pqOfSources.add (selectedSource) - return selectedSource - } - } - } - } - - class ServeMultipleRequests extends Thread with Logging { - override def run = { - var threadPool = Executors.newCachedThreadPool - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket (0) - listenPort = serverSocket.getLocalPort - logInfo("ServeMultipleRequests" + serverSocket + " " + listenPort) - - listenPortLock.synchronized { - listenPortLock.notifyAll - } - - var keepAccepting = true - try { - while (keepAccepting) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout (serverSocketTimout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("ServeMultipleRequests Timeout. Stopping listening...") - keepAccepting = false - } - } - if (clientSocket != null) { - logInfo("Serve:Accepted new client connection:" + clientSocket) - try { - threadPool.execute (new ServeSingleRequest (clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close - } - } - } - } finally { - serverSocket.close - } - } - - class ServeSingleRequest (val clientSocket: Socket) - extends Runnable with Logging { - private val oos = new ObjectOutputStream (clientSocket.getOutputStream) - private val ois = new ObjectInputStream (clientSocket.getInputStream) - - def run = { - try { - logInfo("new ServeSingleRequest is running") - sendObject - } catch { - // TODO: Need to add better exception handling here - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - case e: Exception => { - logInfo("ServeSingleRequest had a " + e) - } - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close - oos.close - clientSocket.close - } - } - - private def sendObject = { - // Wait till receiving the SourceInfo from Master - while (totalBlocks == -1) { - totalBlocksLock.synchronized { - totalBlocksLock.wait - } - } - - for (i <- 0 until totalBlocks) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { - hasBlocksLock.wait - } - } - try { - oos.writeObject (arrayOfBlocks(i)) - oos.flush - } catch { - case e: Exception => { } - } - logInfo("Send block: " + i + " " + arrayOfBlocks(i)) - } - } - } - } -} - -private object BroadcastCH extends Logging { - val values = Cache.newKeySpace() - - private var initialized = false - - private var fileSystem: FileSystem = null - private var workDir: String = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - - def initialize () { - synchronized { - if (!initialized) { - bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val dfs = System.getProperty("spark.dfs", "file:///") - if (!dfs.startsWith("file://")) { - val conf = new Configuration() - conf.setInt("io.file.buffer.size", bufferSize) - val rep = System.getProperty("spark.dfs.replication", "3").toInt - conf.setInt("dfs.replication", rep) - fileSystem = FileSystem.get(new URI(dfs), conf) - } - workDir = System.getProperty("spark.dfs.workdir", "/tmp") - compress = System.getProperty("spark.compress", "false").toBoolean - - initialized = true - } - } - } - - private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid) - - def openFileForReading(uuid: UUID): InputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.open(getPath(uuid)) - } else { - // Local filesystem - new FileInputStream(getPath(uuid).toString) - } - if (compress) - new LZFInputStream(fileStream) // LZF stream does its own buffering - else if (fileSystem == null) - new BufferedInputStream(fileStream, bufferSize) - else - fileStream // Hadoop streams do their own buffering - } - - def openFileForWriting(uuid: UUID): OutputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.create(getPath(uuid)) - } else { - // Local filesystem - new FileOutputStream(getPath(uuid).toString) - } - if (compress) - new LZFOutputStream(fileStream) // LZF stream does its own buffering - else if (fileSystem == null) - new BufferedOutputStream(fileStream, bufferSize) - else - fileStream // Hadoop streams do their own buffering - } } diff --git a/src/scala/spark/ChainedBroadcast.scala b/src/scala/spark/ChainedBroadcast.scala new file mode 100644 index 0000000000..32f97ce442 --- /dev/null +++ b/src/scala/spark/ChainedBroadcast.scala @@ -0,0 +1,863 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{Comparator, PriorityQueue, Random, UUID} + +import com.google.common.collect.MapMaker + +import scala.collection.mutable.{Map, Set} + +@serializable +class ChainedBroadcast[T] (@transient var value_ : T, local: Boolean) +extends Broadcast with Logging { + + def value = value_ + + ChainedBroadcast.synchronized { + ChainedBroadcast.values.put (uuid, value_) + } + + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = 0 + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + @transient var hasBlocksLock = new Object + + @transient var pqOfSources = new PriorityQueue[SourceInfo] + + @transient var serveMR: ServeMultipleRequests = null + @transient var guideMR: GuideMultipleRequests = null + + @transient var hostAddress = InetAddress.getLocalHost.getHostAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + + @transient var hasCopyInHDFS = false + @transient var stopBroadcast = false + + // Must call this after all the variables have been created/initialized + if (!local) { + sendBroadcast + } + + def sendBroadcast (): Unit = { + logInfo ("Local host address: " + hostAddress) + + // Store a persistent copy in HDFS + // TODO: Turned OFF for now + // val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid)) + // out.writeObject (value_) + // out.close + // TODO: Fix this at some point + hasCopyInHDFS = true + + // Create a variableInfo object and store it in valueInfos + var variableInfo = blockifyObject (value_, ChainedBroadcast.BlockSize) + + guideMR = new GuideMultipleRequests + guideMR.setDaemon (true) + guideMR.start + logInfo ("GuideMultipleRequests started...") + + serveMR = new ServeMultipleRequests + serveMR.setDaemon (true) + serveMR.start + logInfo ("ServeMultipleRequests started...") + + // Prepare the value being broadcasted + // TODO: Refactoring and clean-up required here + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + pqOfSources = new PriorityQueue[SourceInfo] + val masterSource_0 = + SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) + pqOfSources.add (masterSource_0) + + // Register with the Tracker + while (guidePort == -1) { + guidePortLock.synchronized { + guidePortLock.wait + } + } + ChainedBroadcast.registerValue (uuid, guidePort) + } + + private def readObject (in: ObjectInputStream): Unit = { + in.defaultReadObject + ChainedBroadcast.synchronized { + val cachedVal = ChainedBroadcast.values.get (uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + // Initializing everything because Master will only send null/0 values + initializeSlaveVariables + + logInfo ("Local host address: " + hostAddress) + + serveMR = new ServeMultipleRequests + serveMR.setDaemon (true) + serveMR.start + logInfo ("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast (uuid) + // If does not succeed, then get from HDFS copy + if (receptionSucceeded) { + value_ = unBlockifyObject[T] + ChainedBroadcast.values.put (uuid, value_) + } else { + val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + ChainedBroadcast.values.put(uuid, value_) + fileIn.close + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } + + private def initializeSlaveVariables: Unit = { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + + listenPortLock = new Object + totalBlocksLock = new Object + hasBlocksLock = new Object + + serveMR = null + + hostAddress = InetAddress.getLocalHost.getHostAddress + listenPort = -1 + + stopBroadcast = false + } + + private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream (baos) + oos.writeObject (obj) + oos.close + baos.close + val byteArray = baos.toByteArray + val bais = new ByteArrayInputStream (byteArray) + + var blockNum = (byteArray.length / blockSize) + if (byteArray.length % blockSize != 0) + blockNum += 1 + + var retVal = new Array[BroadcastBlock] (blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, blockSize)) { + val thisBlockSize = Math.min (blockSize, byteArray.length - i) + var tempByteArray = new Array[Byte] (thisBlockSize) + val hasRead = bais.read (tempByteArray, 0, thisBlockSize) + + retVal (blockID) = new BroadcastBlock (blockID, tempByteArray) + blockID += 1 + } + bais.close + + var variableInfo = VariableInfo (retVal, blockNum, byteArray.length) + variableInfo.hasBlocks = blockNum + + return variableInfo + } + + private def unBlockifyObject[A]: A = { + var retByteArray = new Array[Byte] (totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray, + i * ChainedBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length) + } + byteArrayToObject (retByteArray) + } + + private def byteArrayToObject[A] (bytes: Array[Byte]): A = { + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) + val retVal = in.readObject.asInstanceOf[A] + in.close + return retVal + } + + def getMasterListenPort (variableUUID: UUID): Int = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var masterListenPort: Int = SourceInfo.TxOverGoToHDFS + + var retriesLeft = ChainedBroadcast.MaxRetryCount + do { + try { + // Connect to the tracker to find out the guide + val clientSocketToTracker = + new Socket(ChainedBroadcast.MasterHostAddress, ChainedBroadcast.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream (clientSocketToTracker.getOutputStream) + oosTracker.flush + val oisTracker = + new ObjectInputStream (clientSocketToTracker.getInputStream) + + // Send UUID and receive masterListenPort + oosTracker.writeObject (uuid) + oosTracker.flush + masterListenPort = oisTracker.readObject.asInstanceOf[Int] + } catch { + case e: Exception => { + logInfo ("getMasterListenPort had a " + e) + } + } finally { + if (oisTracker != null) { + oisTracker.close + } + if (oosTracker != null) { + oosTracker.close + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close + } + } + retriesLeft -= 1 + + Thread.sleep (ChainedBroadcast.ranGen.nextInt ( + ChainedBroadcast.MaxKnockInterval - ChainedBroadcast.MinKnockInterval) + + ChainedBroadcast.MinKnockInterval) + + } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) + + logInfo ("Got this guidePort from Tracker: " + masterListenPort) + return masterListenPort + } + + def receiveBroadcast (variableUUID: UUID): Boolean = { + val masterListenPort = getMasterListenPort (variableUUID) + + if (masterListenPort == SourceInfo.TxOverGoToHDFS || + masterListenPort == SourceInfo.TxNotStartedRetry) { + // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go + // to HDFS anyway when receiveBroadcast returns false + return false + } + + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + var clientSocketToMaster: Socket = null + var oosMaster: ObjectOutputStream = null + var oisMaster: ObjectInputStream = null + + // Connect and receive broadcast from the specified source, retrying the + // specified number of times in case of failures + var retriesLeft = ChainedBroadcast.MaxRetryCount + do { + // Connect to Master and send this worker's Information + clientSocketToMaster = + new Socket(ChainedBroadcast.MasterHostAddress, masterListenPort) + // TODO: Guiding object connection is reusable + oosMaster = + new ObjectOutputStream (clientSocketToMaster.getOutputStream) + oosMaster.flush + oisMaster = + new ObjectInputStream (clientSocketToMaster.getInputStream) + + logInfo ("Connected to Master's guiding object") + + // Send local source information + oosMaster.writeObject(SourceInfo (hostAddress, listenPort, -1, -1, 0)) + oosMaster.flush + + // Receive source information from Master + var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] + totalBlocks = sourceInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks) + totalBlocksLock.synchronized { + totalBlocksLock.notifyAll + } + totalBytes = sourceInfo.totalBytes + + logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + + val start = System.nanoTime + val receptionSucceeded = receiveSingleTransmission (sourceInfo) + val time = (System.nanoTime - start) / 1e9 + + // Updating some statistics in sourceInfo. Master will be using them later + if (!receptionSucceeded) { + sourceInfo.receptionFailed = true + } + + // Send back statistics to the Master + oosMaster.writeObject (sourceInfo) + + if (oisMaster != null) { + oisMaster.close + } + if (oosMaster != null) { + oosMaster.close + } + if (clientSocketToMaster != null) { + clientSocketToMaster.close + } + + retriesLeft -= 1 + } while (retriesLeft > 0 && hasBlocks < totalBlocks) + + return (hasBlocks == totalBlocks) + } + + // Tries to receive broadcast from the source and returns Boolean status. + // This might be called multiple times to retry a defined number of times. + private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { + var clientSocketToSource: Socket = null + var oosSource: ObjectOutputStream = null + var oisSource: ObjectInputStream = null + + var receptionSucceeded = false + try { + // Connect to the source to get the object itself + clientSocketToSource = + new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) + oosSource = + new ObjectOutputStream (clientSocketToSource.getOutputStream) + oosSource.flush + oisSource = + new ObjectInputStream (clientSocketToSource.getInputStream) + + logInfo ("Inside receiveSingleTransmission") + logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) + + // Send the range + oosSource.writeObject((hasBlocks, totalBlocks)) + oosSource.flush + + for (i <- hasBlocks until totalBlocks) { + val recvStartTime = System.currentTimeMillis + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + val receptionTime = (System.currentTimeMillis - recvStartTime) + + logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") + + arrayOfBlocks(hasBlocks) = bcBlock + hasBlocks += 1 + // Set to true if at least one block is received + receptionSucceeded = true + hasBlocksLock.synchronized { + hasBlocksLock.notifyAll + } + } + } catch { + case e: Exception => { + logInfo ("receiveSingleTransmission had a " + e) + } + } finally { + if (oisSource != null) { + oisSource.close + } + if (oosSource != null) { + oosSource.close + } + if (clientSocketToSource != null) { + clientSocketToSource.close + } + } + + return receptionSucceeded + } + + class GuideMultipleRequests + extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo] () + + override def run: Unit = { + var threadPool = Broadcast.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (0) + guidePort = serverSocket.getLocalPort + logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort) + + guidePortLock.synchronized { + guidePortLock.notifyAll + } + + try { + // Don't stop until there is a copy in HDFS + while (!stopBroadcast || !hasCopyInHDFS) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("GuideMultipleRequests Timeout.") + + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // pqOfSources.size - 1, because it includes the Guide itself + if (pqOfSources.size > 1 && + setOfCompletedSources.size == pqOfSources.size - 1) { + stopBroadcast = true + } + } + } + if (clientSocket != null) { + logInfo ("Guide: Accepted new client connection: " + clientSocket) + try { + threadPool.execute (new GuideSingleRequest (clientSocket)) + } catch { + // In failure, close the socket here; else, the thread will close it + case ioe: IOException => clientSocket.close + } + } + } + + logInfo ("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + + ChainedBroadcast.unregisterValue (uuid) + } finally { + if (serverSocket != null) { + logInfo ("GuideMultipleRequests now stopping...") + serverSocket.close + } + } + + // Shutdown the thread pool + threadPool.shutdown + } + + private def sendStopBroadcastNotifications: Unit = { + pqOfSources.synchronized { + var pqIter = pqOfSources.iterator + while (pqIter.hasNext) { + var sourceInfo = pqIter.next + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = + new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = + new ObjectOutputStream (guideSocketToSource.getOutputStream) + gosSource.flush + gisSource = + new ObjectInputStream (guideSocketToSource.getInputStream) + + // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 + gosSource.writeObject ((SourceInfo.StopBroadcast, + SourceInfo.StopBroadcast)) + gosSource.flush + } catch { + case e: Exception => { + logInfo ("sendStopBroadcastNotifications had a " + e) + } + } finally { + if (gisSource != null) { + gisSource.close + } + if (gosSource != null) { + gosSource.close + } + if (guideSocketToSource != null) { + guideSocketToSource.close + } + } + } + } + } + + class GuideSingleRequest (val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + private var selectedSourceInfo: SourceInfo = null + private var thisWorkerInfo:SourceInfo = null + + override def run: Unit = { + try { + logInfo ("new GuideSingleRequest is running") + // Connecting worker is sending in its hostAddress and listenPort it will + // be listening to. ReplicaID is 0 and other fields are invalid (-1) + var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + pqOfSources.synchronized { + // Select a suitable source and send it back to the worker + selectedSourceInfo = selectSuitableSource (sourceInfo) + logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo) + oos.writeObject (selectedSourceInfo) + oos.flush + + // Add this new (if it can finish) source to the PQ of sources + thisWorkerInfo = SourceInfo (sourceInfo.hostAddress, + sourceInfo.listenPort, totalBlocks, totalBytes, 0) + logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo) + pqOfSources.add (thisWorkerInfo) + } + + // Wait till the whole transfer is done. Then receive and update source + // statistics in pqOfSources + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + pqOfSources.synchronized { + // This should work since SourceInfo is a case class + assert (pqOfSources.contains (selectedSourceInfo)) + + // Remove first + pqOfSources.remove (selectedSourceInfo) + // TODO: Removing a source based on just one failure notification! + + // Update sourceInfo and put it back in, IF reception succeeded + if (!sourceInfo.receptionFailed) { + // Add thisWorkerInfo to sources that have completed reception + setOfCompletedSources += thisWorkerInfo + + selectedSourceInfo.currentLeechers -= 1 + + // Put it back + pqOfSources.add (selectedSourceInfo) + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + case e: Exception => { + // Assuming that exception caused due to receiver worker failure. + // Remove failed worker from pqOfSources and update leecherCount of + // corresponding source worker + pqOfSources.synchronized { + if (selectedSourceInfo != null) { + // Remove first + pqOfSources.remove (selectedSourceInfo) + // Update leecher count and put it back in + selectedSourceInfo.currentLeechers -= 1 + pqOfSources.add (selectedSourceInfo) + } + + // Remove thisWorkerInfo + if (pqOfSources != null) { + pqOfSources.remove (thisWorkerInfo) + } + } + } + } finally { + ois.close + oos.close + clientSocket.close + } + } + + // TODO: Caller must have a synchronized block on pqOfSources + // TODO: If a worker fails to get the broadcasted variable from a source and + // comes back to Master, this function might choose the worker itself as a + // source tp create a dependency cycle (this worker was put into pqOfSources + // as a streming source when it first arrived). The length of this cycle can + // be arbitrarily long. + private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { + // Select one based on the ordering strategy (e.g., least leechers etc.) + // take is a blocking call removing the element from PQ + var selectedSource = pqOfSources.poll + assert (selectedSource != null) + // Update leecher count + selectedSource.currentLeechers += 1 + // Add it back and then return + pqOfSources.add (selectedSource) + return selectedSource + } + } + } + + class ServeMultipleRequests + extends Thread with Logging { + override def run: Unit = { + var threadPool = Broadcast.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (0) + listenPort = serverSocket.getLocalPort + logInfo ("ServeMultipleRequests started with " + serverSocket) + + listenPortLock.synchronized { + listenPortLock.notifyAll + } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("ServeMultipleRequests Timeout.") + } + } + if (clientSocket != null) { + logInfo ("Serve: Accepted new client connection: " + clientSocket) + try { + threadPool.execute (new ServeSingleRequest (clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => clientSocket.close + } + } + } + } finally { + if (serverSocket != null) { + logInfo ("ServeMultipleRequests now stopping...") + serverSocket.close + } + } + + // Shutdown the thread pool + threadPool.shutdown + } + + class ServeSingleRequest (val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + private var sendFrom = 0 + private var sendUntil = totalBlocks + + override def run: Unit = { + try { + logInfo ("new ServeSingleRequest is running") + + // Receive range to send + var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] + sendFrom = rangeToSend._1 + sendUntil = rangeToSend._2 + + if (sendFrom == SourceInfo.StopBroadcast && + sendUntil == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + // Carry on + sendObject + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + case e: Exception => { + logInfo ("ServeSingleRequest had a " + e) + } + } finally { + logInfo ("ServeSingleRequest is closing streams and sockets") + ois.close + oos.close + clientSocket.close + } + } + + private def sendObject: Unit = { + // Wait till receiving the SourceInfo from Master + while (totalBlocks == -1) { + totalBlocksLock.synchronized { + totalBlocksLock.wait + } + } + + for (i <- sendFrom until sendUntil) { + while (i == hasBlocks) { + hasBlocksLock.synchronized { + hasBlocksLock.wait + } + } + try { + oos.writeObject (arrayOfBlocks(i)) + oos.flush + } catch { + case e: Exception => { + logInfo ("sendObject had a " + e) + } + } + logInfo ("Sent block: " + i + " to " + clientSocket) + } + } + } + } +} + +private object ChainedBroadcast +extends Logging { + val values = new MapMaker ().softValues ().makeMap[UUID, Any] + + var valueToGuidePortMap = Map[UUID, Int] () + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var isMaster_ = false + + private var MasterHostAddress_ = "127.0.0.1" + private var MasterTrackerPort_ : Int = 22222 + private var BlockSize_ : Int = 512 * 1024 + private var MaxRetryCount_ : Int = 2 + + private var TrackerSocketTimeout_ : Int = 50000 + private var ServerSocketTimeout_ : Int = 10000 + + private var trackMV: TrackMultipleValues = null + + private var MinKnockInterval_ = 500 + private var MaxKnockInterval_ = 999 + + def initialize (isMaster__ : Boolean): Unit = { + synchronized { + if (!initialized) { + MasterHostAddress_ = + System.getProperty ("spark.broadcast.MasterHostAddress", "127.0.0.1") + MasterTrackerPort_ = + System.getProperty ("spark.broadcast.MasterTrackerPort", "22222").toInt + BlockSize_ = + System.getProperty ("spark.broadcast.BlockSize", "512").toInt * 1024 + MaxRetryCount_ = + System.getProperty ("spark.broadcast.MaxRetryCount", "2").toInt + + TrackerSocketTimeout_ = + System.getProperty ("spark.broadcast.TrackerSocketTimeout", "50000").toInt + ServerSocketTimeout_ = + System.getProperty ("spark.broadcast.ServerSocketTimeout", "10000").toInt + + MinKnockInterval_ = + System.getProperty ("spark.broadcast.MinKnockInterval", "500").toInt + MaxKnockInterval_ = + System.getProperty ("spark.broadcast.MaxKnockInterval", "999").toInt + + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + trackMV.setDaemon (true) + trackMV.start + logInfo ("TrackMultipleValues started...") + } + + initialized = true + } + } + } + + def MasterHostAddress = MasterHostAddress_ + def MasterTrackerPort = MasterTrackerPort_ + def BlockSize = BlockSize_ + def MaxRetryCount = MaxRetryCount_ + + def TrackerSocketTimeout = TrackerSocketTimeout_ + def ServerSocketTimeout = ServerSocketTimeout_ + + def isMaster = isMaster_ + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + def registerValue (uuid: UUID, guidePort: Int): Unit = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap += (uuid -> guidePort) + logInfo ("New value registered with the Tracker " + valueToGuidePortMap) + } + } + + def unregisterValue (uuid: UUID): Unit = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS + logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap) + } + } + + class TrackMultipleValues + extends Thread with Logging { + override def run: Unit = { + var threadPool = Broadcast.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (ChainedBroadcast.MasterTrackerPort) + logInfo ("TrackMultipleValues" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (TrackerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("TrackMultipleValues Timeout. Stopping listening...") + } + } + + if (clientSocket != null) { + try { + threadPool.execute (new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + val ois = new ObjectInputStream (clientSocket.getInputStream) + try { + val uuid = ois.readObject.asInstanceOf[UUID] + var guidePort = + if (valueToGuidePortMap.contains (uuid)) { + valueToGuidePortMap (uuid) + } else SourceInfo.TxNotStartedRetry + logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) + oos.writeObject (guidePort) + } catch { + case e: Exception => { + logInfo ("TrackMultipleValues had a " + e) + } + } finally { + ois.close + oos.close + clientSocket.close + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => clientSocket.close + } + } + } + } finally { + serverSocket.close + } + + // Shutdown the thread pool + threadPool.shutdown + } + } +} diff --git a/src/scala/spark/DfsBroadcast.scala b/src/scala/spark/DfsBroadcast.scala new file mode 100644 index 0000000000..5be5f98e8c --- /dev/null +++ b/src/scala/spark/DfsBroadcast.scala @@ -0,0 +1,127 @@ +package spark + +import com.google.common.collect.MapMaker + +import java.io._ +import java.net._ +import java.util.UUID + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} + +import spark.compress.lzf.{LZFInputStream, LZFOutputStream} + +@serializable +class DfsBroadcast[T](@transient var value_ : T, local: Boolean) +extends Broadcast with Logging { + + def value = value_ + + DfsBroadcast.synchronized { + DfsBroadcast.values.put(uuid, value_) + } + + if (!local) { + sendBroadcast + } + + def sendBroadcast (): Unit = { + val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid)) + out.writeObject (value_) + out.close + } + + // Called by JVM when deserializing an object + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject + DfsBroadcast.synchronized { + val cachedVal = DfsBroadcast.values.get(uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + logInfo( "Started reading Broadcasted variable " + uuid) + val start = System.nanoTime + + val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + DfsBroadcast.values.put(uuid, value_) + fileIn.close + + val time = (System.nanoTime - start) / 1e9 + logInfo( "Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } +} + +private object DfsBroadcast +extends Logging { + val values = new MapMaker ().softValues ().makeMap[UUID, Any] + + private var initialized = false + + private var fileSystem: FileSystem = null + private var workDir: String = null + private var compress: Boolean = false + private var bufferSize: Int = 65536 + + def initialize (): Unit = { + synchronized { + if (!initialized) { + bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val dfs = System.getProperty("spark.dfs", "file:///") + if (!dfs.startsWith("file://")) { + val conf = new Configuration() + conf.setInt("io.file.buffer.size", bufferSize) + val rep = System.getProperty("spark.dfs.replication", "3").toInt + conf.setInt("dfs.replication", rep) + fileSystem = FileSystem.get(new URI(dfs), conf) + } + workDir = System.getProperty("spark.dfs.workdir", "/tmp") + compress = System.getProperty("spark.compress", "false").toBoolean + + initialized = true + } + } + } + + private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid) + + def openFileForReading(uuid: UUID): InputStream = { + val fileStream = if (fileSystem != null) { + fileSystem.open(getPath(uuid)) + } else { + // Local filesystem + new FileInputStream(getPath(uuid).toString) + } + + if (compress) { + // LZF stream does its own buffering + new LZFInputStream(fileStream) + } else if (fileSystem == null) { + new BufferedInputStream(fileStream, bufferSize) + } else { + // Hadoop streams do their own buffering + fileStream + } + } + + def openFileForWriting(uuid: UUID): OutputStream = { + val fileStream = if (fileSystem != null) { + fileSystem.create(getPath(uuid)) + } else { + // Local filesystem + new FileOutputStream(getPath(uuid).toString) + } + + if (compress) { + // LZF stream does its own buffering + new LZFOutputStream(fileStream) + } else if (fileSystem == null) { + new BufferedOutputStream(fileStream, bufferSize) + } else { + // Hadoop streams do their own buffering + fileStream + } + } +} diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala index 02e80c7756..8b8e408266 100644 --- a/src/scala/spark/SparkContext.scala +++ b/src/scala/spark/SparkContext.scala @@ -3,6 +3,7 @@ package spark import java.io._ import scala.collection.mutable.ArrayBuffer +import scala.actors.Actor._ import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.SequenceFileInputFormat @@ -46,6 +47,23 @@ extends Logging { def textFile(path: String): RDD[String] = new HadoopTextFile(this, path) + // TODO: Keep around a weak hash map of values to Cached versions? + // def broadcast[T](value: T) = new DfsBroadcast(value, isLocal) + def broadcast[T](value: T) = new ChainedBroadcast(value, isLocal) + +// def broadcast[T](value: T) = { +// val broadcastClass = System.getProperty("spark.broadcast.Class", +// "spark.ChainedBroadcast") +// val booleanArgs = Array[AnyRef] (local.asInstanceOf[AnyRef]) +// Class.forName(broadcastClass).getConstructors()(0).newInstance(booleanArgs:_*).asInstanceOf[Class.forName(broadcastClass)] +// } + +// def initialize() { +// val cacheClass = System.getProperty("spark.cache.class", +// "spark.SoftReferenceCache") +// instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] +// } + /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ def hadoopFile[K, V](path: String, inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -97,10 +115,6 @@ extends Logging { def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) - // TODO: Keep around a weak hash map of values to Cached versions? - def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, isLocal) - //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, isLocal) - // Stop the SparkContext def stop() { scheduler.stop() diff --git a/src/scala/spark/repl/ClassServer.scala b/src/scala/spark/repl/ClassServer.scala new file mode 100644 index 0000000000..6a40d92765 --- /dev/null +++ b/src/scala/spark/repl/ClassServer.scala @@ -0,0 +1,77 @@ +package spark.repl + +import java.io.File +import java.net.InetAddress + +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.handler.DefaultHandler +import org.eclipse.jetty.server.handler.HandlerList +import org.eclipse.jetty.server.handler.ResourceHandler + +import spark.Logging + + +/** + * Exception type thrown by ClassServer when it is in the wrong state + * for an operation. + */ +class ServerStateException(message: String) extends Exception(message) + + +/** + * An HTTP server used by the interpreter to allow worker nodes to access + * class files created as the user types in lines of code. This is just a + * wrapper around a Jetty embedded HTTP server. + */ +class ClassServer(classDir: File) extends Logging { + private var server: Server = null + private var port: Int = -1 + + def start() { + if (server != null) { + throw new ServerStateException("Server is already started") + } else { + server = new Server(0) + val resHandler = new ResourceHandler + resHandler.setResourceBase(classDir.getAbsolutePath) + val handlerList = new HandlerList + handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + server.setHandler(handlerList) + server.start() + port = server.getConnectors()(0).getLocalPort() + logDebug("ClassServer started at " + uri) + } + } + + def stop() { + if (server == null) { + throw new ServerStateException("Server is already stopped") + } else { + server.stop() + port = -1 + server = null + } + } + + /** + * Get the URI of this HTTP server (http://host:port) + */ + def uri: String = { + if (server == null) { + throw new ServerStateException("Server is not started") + } else { + return "http://" + getLocalIpAddress + ":" + port + } + } + + /** + * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4) + */ + private def getLocalIpAddress: String = { + // Get local IP as an array of four bytes + val bytes = InetAddress.getLocalHost().getAddress() + // Convert the bytes to ints (keeping in mind that they may be negative) + // and join them into a string + return bytes.map(b => (b.toInt + 256) % 256).mkString(".") + } +} |