Diffstat:
-rwxr-xr-x  run                                   |    2
-rw-r--r--  src/examples/BroadcastTest.scala      |   24
-rw-r--r--  src/examples/SparkALS.scala           |   12
-rw-r--r--  src/scala/spark/Broadcast.scala       |  798
-rw-r--r--  src/scala/spark/Cached.scala          |  110
-rw-r--r--  src/scala/spark/Executor.scala        |    4
-rw-r--r--  src/scala/spark/HdfsFile.scala        |   38
-rw-r--r--  src/scala/spark/NexusScheduler.scala  |  317
-rw-r--r--  src/scala/spark/SparkContext.scala    |    5
-rw-r--r--  src/scala/spark/Task.scala            |    4
-rw-r--r--  src/test/spark/repl/ReplSuite.scala   |   18
11 files changed, 1039 insertions, 293 deletions
diff --git a/run b/run
index 456615fba4..c1156892ad 100755
--- a/run
+++ b/run
@@ -4,7 +4,7 @@
FWDIR=`dirname $0`
# Set JAVA_OPTS to be able to load libnexus.so and set various other misc options
-JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xmx750m"
+export JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xmx2000m -Dspark.broadcast.masterHostAddress=127.0.0.1 -Dspark.broadcast.masterListenPort=11111 -Dspark.broadcast.blockSize=1024 -Dspark.broadcast.maxRetryCount=2 -Dspark.broadcast.serverSocketTimout=50000 -Dspark.broadcast.dualMode=false"
if [ -e $FWDIR/conf/java-opts ] ; then
JAVA_OPTS+=" `cat $FWDIR/conf/java-opts`"
fi
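The new JAVA_OPTS line seeds the spark.broadcast.* system properties that BroadcastCS.initialize reads further down in this commit. A minimal, self-contained sketch (the object name is illustrative) of how those -D flags surface as system properties in the JVM:

object BroadcastPropsSketch {
  def main(args: Array[String]) {
    val masterHost = System.getProperty("spark.broadcast.masterHostAddress", "127.0.0.1")
    val masterPort = System.getProperty("spark.broadcast.masterListenPort", "11111").toInt
    // BroadcastCS treats the property as KB and multiplies by 1024, so the
    // value 1024 set above corresponds to 1 MB blocks
    val blockSize  = System.getProperty("spark.broadcast.blockSize", "512").toInt * 1024
    val dualMode   = System.getProperty("spark.broadcast.dualMode", "false").toBoolean
    println("broadcast master " + masterHost + ":" + masterPort +
            ", blockSize=" + blockSize + " bytes, dualMode=" + dualMode)
  }
}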
diff --git a/src/examples/BroadcastTest.scala b/src/examples/BroadcastTest.scala
new file mode 100644
index 0000000000..7764013413
--- /dev/null
+++ b/src/examples/BroadcastTest.scala
@@ -0,0 +1,24 @@
+import spark.SparkContext
+
+object BroadcastTest {
+ def main(args: Array[String]) {
+ if (args.length == 0) {
+ System.err.println("Usage: BroadcastTest <host> [<slices>]")
+ System.exit(1)
+ }
+ val spark = new SparkContext(args(0), "Broadcast Test")
+ val slices = if (args.length > 1) args(1).toInt else 2
+ val num = if (args.length > 2) args(2).toInt else 1000000
+
+ var arr = new Array[Int](num)
+ for (i <- 0 until arr.length)
+ arr(i) = i
+
+ val barr = spark.broadcast(arr)
+ spark.parallelize(1 to 10, slices).foreach {
+ println("in task: barr = " + barr)
+ i => println(barr.value.size)
+ }
+ }
+}
+
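Note that in the foreach block above, the first println is a plain statement placed before the function literal, so it runs once on the driver while the closure is being built rather than inside each task. A sketch of the presumably intended form, with both prints inside the closure (argument handling trimmed for brevity):

import spark.SparkContext

object BroadcastTestSketch {
  def main(args: Array[String]) {
    val spark = new SparkContext(args(0), "Broadcast Test (sketch)")
    val barr = spark.broadcast(Array.range(0, 1000000))
    spark.parallelize(1 to 10, 2).foreach { i =>
      // Both lines now run on the workers, once per element
      println("in task: barr = " + barr)
      println(barr.value.size)
    }
  }
}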
diff --git a/src/examples/SparkALS.scala b/src/examples/SparkALS.scala
index 2fd58ed3a5..38dd0e665d 100644
--- a/src/examples/SparkALS.scala
+++ b/src/examples/SparkALS.scala
@@ -119,18 +119,18 @@ object SparkALS {
// Iteratively update movies then users
val Rc = spark.broadcast(R)
- var msb = spark.broadcast(ms)
- var usb = spark.broadcast(us)
+ var msc = spark.broadcast(ms)
+ var usc = spark.broadcast(us)
for (iter <- 1 to ITERATIONS) {
println("Iteration " + iter + ":")
ms = spark.parallelize(0 until M, slices)
- .map(i => updateMovie(i, msb.value(i), usb.value, Rc.value))
+ .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
.toArray
- msb = spark.broadcast(ms) // Re-broadcast ms because it was updated
+ msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
us = spark.parallelize(0 until U, slices)
- .map(i => updateUser(i, usb.value(i), msb.value, Rc.value))
+ .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
.toArray
- usb = spark.broadcast(us) // Re-broadcast us because it was updated
+ usc = spark.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))
println()
}
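The rename from msb/usb to msc/usc is cosmetic; the substantive point is that each iteration must re-broadcast the updated arrays, because a broadcast value is a snapshot taken when it is created. A tiny self-contained sketch of that re-broadcast pattern (not the ALS code itself; all names here are illustrative):

import spark.SparkContext

object RebroadcastSketch {
  def main(args: Array[String]) {
    val spark = new SparkContext(args(0), "Rebroadcast Sketch")
    var arr = Array.fill(5)(0)
    var barr = spark.broadcast(arr)
    for (iter <- 1 to 3) {
      // Workers read the snapshot taken at the last broadcast call
      arr = spark.parallelize(0 until arr.length, 2)
                 .map(i => barr.value(i) + 1)
                 .toArray
      barr = spark.broadcast(arr) // re-broadcast so the next iteration sees the update
    }
    println(arr.mkString(" "))
  }
}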
diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala
new file mode 100644
index 0000000000..2da5e28a0a
--- /dev/null
+++ b/src/scala/spark/Broadcast.scala
@@ -0,0 +1,798 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{UUID, PriorityQueue, Comparator}
+
+import com.google.common.collect.MapMaker
+
+import java.util.concurrent.{Executors, ExecutorService}
+
+import scala.actors.Actor
+import scala.actors.Actor._
+
+import scala.collection.mutable.Map
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
+
+import spark.compress.lzf.{LZFInputStream, LZFOutputStream}
+
+@serializable
+trait BroadcastRecipe {
+ val uuid = UUID.randomUUID
+
+ // We cannot have an abstract readObject here due to some weird issues with
+ // readObject having to be 'private' in sub-classes. Possibly a Scala bug!
+ def sendBroadcast: Unit
+
+ override def toString = "spark.Broadcast(" + uuid + ")"
+}
+
+// TODO: Should think about storing in HDFS in the future
+// TODO: Right now, no parallelization between multiple broadcasts
+@serializable
+class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean)
+ extends BroadcastRecipe {
+
+ def value = value_
+
+ BroadcastCS.synchronized { BroadcastCS.values.put (uuid, value_) }
+
+ if (!local) { sendBroadcast }
+
+ def sendBroadcast () {
+ // Create a variableInfo object and store it in valueInfos
+ var variableInfo = blockifyObject (value_, BroadcastCS.blockSize)
+ // TODO: Even though this part is not in use now, there is a problem in the
+ // following statement: we shouldn't be using a constant port and hostAddress anymore.
+ // val masterSource =
+ // new SourceInfo (BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort,
+ // variableInfo.totalBlocks, variableInfo.totalBytes, 0)
+ // variableInfo.pqOfSources.add (masterSource)
+
+ BroadcastCS.synchronized {
+ // BroadcastCS.valueInfos.put (uuid, variableInfo)
+
+ // TODO: Not using variableInfo in current implementation. Manually
+ // setting all the variables inside BroadcastCS object
+
+ BroadcastCS.initializeVariable (variableInfo)
+ }
+
+ // Now store a persistent copy in HDFS, just in case
+ val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
+ out.writeObject (value_)
+ out.close
+ }
+
+ private def readObject (in: ObjectInputStream) {
+ in.defaultReadObject
+ BroadcastCS.synchronized {
+ val cachedVal = BroadcastCS.values.get (uuid)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ // Only a single worker (the first one) on each node can ever get here.
+ // The rest will always find the value already cached
+ val start = System.nanoTime
+
+ val retByteArray = BroadcastCS.receiveBroadcast (uuid)
+ // If that does not succeed, get it from the HDFS copy
+ if (retByteArray != null) {
+ value_ = byteArrayToObject[T] (retByteArray)
+ BroadcastCS.values.put (uuid, value_)
+ // val variableInfo = blockifyObject (value_, BroadcastCS.blockSize)
+ // BroadcastCS.valueInfos.put (uuid, variableInfo)
+ } else {
+ val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid))
+ value_ = fileIn.readObject.asInstanceOf[T]
+ BroadcastCH.values.put(uuid, value_)
+ fileIn.close
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ println( System.currentTimeMillis + ": " + "Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+
+ private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
+ val baos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream (baos)
+ oos.writeObject (obj)
+ oos.close
+ baos.close
+ val byteArray = baos.toByteArray
+ val bais = new ByteArrayInputStream (byteArray)
+
+ var blockNum = (byteArray.length / blockSize)
+ if (byteArray.length % blockSize != 0)
+ blockNum += 1
+
+ var retVal = new Array[BroadcastBlock] (blockNum)
+ var blockID = 0
+
+ // TODO: What happens if byteArray.length == 0 => blockNum == 0
+ for (i <- 0 until (byteArray.length, blockSize)) {
+ val thisBlockSize = Math.min (blockSize, byteArray.length - i)
+ var tempByteArray = new Array[Byte] (thisBlockSize)
+ val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
+
+ retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close
+
+ var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
+ variableInfo.hasBlocks = blockNum
+
+ return variableInfo
+ }
+
+ private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
+ val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
+ val retVal = in.readObject.asInstanceOf[A]
+ in.close
+ return retVal
+ }
+
+ private def getByteArrayOutputStream (obj: T): ByteArrayOutputStream = {
+ val bOut = new ByteArrayOutputStream
+ val out = new ObjectOutputStream (bOut)
+ out.writeObject (obj)
+ out.close
+ bOut.close
+ return bOut
+ }
+}
+
+@serializable
+class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean)
+ extends BroadcastRecipe {
+
+ def value = value_
+
+ BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) }
+
+ if (!local) { sendBroadcast }
+
+ def sendBroadcast () {
+ val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
+ out.writeObject (value_)
+ out.close
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject
+ BroadcastCH.synchronized {
+ val cachedVal = BroadcastCH.values.get(uuid)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ val start = System.nanoTime
+
+ val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid))
+ value_ = fileIn.readObject.asInstanceOf[T]
+ BroadcastCH.values.put(uuid, value_)
+ fileIn.close
+
+ val time = (System.nanoTime - start) / 1e9
+ println( System.currentTimeMillis + ": " + "Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+}
+
+@serializable
+case class SourceInfo (val hostAddress: String, val listenPort: Int,
+ val totalBlocks: Int, val totalBytes: Int, val replicaID: Int)
+ extends Comparable [SourceInfo]{
+
+ var currentLeechers = 0
+ var receptionFailed = false
+
+ def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
+}
+
+@serializable
+case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
+
+@serializable
+case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
+ val totalBlocks: Int, val totalBytes: Int) {
+
+ @transient var hasBlocks = 0
+
+ val listenPortLock = new AnyRef
+ val totalBlocksLock = new AnyRef
+ val hasBlocksLock = new AnyRef
+
+ @transient var pqOfSources = new PriorityQueue[SourceInfo]
+}
+
+private object Broadcast {
+ private var initialized = false
+
+ // Will be called by SparkContext or Executor before using Broadcast
+ // Calls all other initializers here
+ def initialize (isMaster: Boolean) {
+ synchronized {
+ if (!initialized) {
+ // Initialization for CentralizedHDFSBroadcast
+ BroadcastCH.initialize
+ // Initialization for ChainedStreamingBroadcast
+ BroadcastCS.initialize (isMaster)
+
+ initialized = true
+ }
+ }
+ }
+}
+
+private object BroadcastCS {
+ val values = new MapMaker ().softValues ().makeMap[UUID, Any]
+ // val valueInfos = new MapMaker ().softValues ().makeMap[UUID, Any]
+
+ // private var valueToPort = Map[UUID, Int] ()
+
+ private var initialized = false
+ private var isMaster_ = false
+
+ private var masterHostAddress_ = "127.0.0.1"
+ private var masterListenPort_ : Int = 11111
+ private var blockSize_ : Int = 512 * 1024
+ private var maxRetryCount_ : Int = 2
+ private var serverSocketTimout_ : Int = 50000
+ private var dualMode_ : Boolean = false
+
+ private val hostAddress = InetAddress.getLocalHost.getHostAddress
+ private var listenPort = -1
+
+ var arrayOfBlocks: Array[BroadcastBlock] = null
+ var totalBytes = -1
+ var totalBlocks = -1
+ var hasBlocks = 0
+
+ val listenPortLock = new Object
+ val totalBlocksLock = new Object
+ val hasBlocksLock = new Object
+
+ var pqOfSources = new PriorityQueue[SourceInfo]
+
+ private var serveMR: ServeMultipleRequests = null
+ private var guideMR: GuideMultipleRequests = null
+
+ def initialize (isMaster__ : Boolean) {
+ synchronized {
+ if (!initialized) {
+ masterHostAddress_ =
+ System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1")
+ masterListenPort_ =
+ System.getProperty ("spark.broadcast.masterListenPort", "11111").toInt
+ blockSize_ =
+ System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
+ maxRetryCount_ =
+ System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt
+ serverSocketTimout_ =
+ System.getProperty ("spark.broadcast.serverSocketTimout", "50000").toInt
+ dualMode_ =
+ System.getProperty ("spark.broadcast.dualMode", "false").toBoolean
+
+ isMaster_ = isMaster__
+
+ if (isMaster) {
+ guideMR = new GuideMultipleRequests
+ // guideMR.setDaemon (true)
+ guideMR.start
+ println (System.currentTimeMillis + ": " + "GuideMultipleRequests started")
+ }
+ serveMR = new ServeMultipleRequests
+ // serveMR.setDaemon (true)
+ serveMR.start
+
+ println (System.currentTimeMillis + ": " + "ServeMultipleRequests started")
+
+ println (System.currentTimeMillis + ": " + "BroadcastCS object has been initialized")
+
+ initialized = true
+ }
+ }
+ }
+
+ // TODO: This should change in a future implementation.
+ // Called from the Master constructor to set up state for the particular
+ // variable that is being broadcast
+ def initializeVariable (variableInfo: VariableInfo) {
+ arrayOfBlocks = variableInfo.arrayOfBlocks
+ totalBytes = variableInfo.totalBytes
+ totalBlocks = variableInfo.totalBlocks
+ hasBlocks = variableInfo.totalBlocks
+
+ // listenPort should already be valid
+ assert (listenPort != -1)
+
+ pqOfSources = new PriorityQueue[SourceInfo]
+ val masterSource_0 =
+ new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0)
+ BroadcastCS.pqOfSources.add (masterSource_0)
+ // Add one more time to have two replicas of any seeds in the PQ
+ if (BroadcastCS.dualMode) {
+ val masterSource_1 =
+ new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1)
+ BroadcastCS.pqOfSources.add (masterSource_1)
+ }
+ }
+
+ def masterHostAddress = masterHostAddress_
+ def masterListenPort = masterListenPort_
+ def blockSize = blockSize_
+ def maxRetryCount = maxRetryCount_
+ def serverSocketTimout = serverSocketTimout_
+ def dualMode = dualMode_
+
+ def isMaster = isMaster_
+
+ def receiveBroadcast (variableUUID: UUID): Array[Byte] = {
+ // Wait until hostAddress and listenPort are created by the
+ // ServeMultipleRequests thread (it is started much earlier, so usually
+ // there is no actual waiting here)
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ // Connect and receive broadcast from the specified source, retrying the
+ // specified number of times in case of failures
+ var retriesLeft = BroadcastCS.maxRetryCount
+ var retByteArray: Array[Byte] = null
+ do {
+ // Connect to Master and send this worker's Information
+ val clientSocketToMaster =
+ new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort)
+ println (System.currentTimeMillis + ": " + "Connected to Master's guiding object")
+ // TODO: Guiding object connection is reusable
+ val oisMaster =
+ new ObjectInputStream (clientSocketToMaster.getInputStream)
+ val oosMaster =
+ new ObjectOutputStream (clientSocketToMaster.getOutputStream)
+
+ oosMaster.writeObject(new SourceInfo (hostAddress, listenPort, -1, -1, 0))
+ oosMaster.flush
+
+ // Receive source information from Master
+ var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
+ totalBlocks = sourceInfo.totalBlocks
+ arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
+ totalBlocksLock.synchronized {
+ totalBlocksLock.notifyAll
+ }
+ totalBytes = sourceInfo.totalBytes
+
+ println (System.currentTimeMillis + ": " + "Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
+
+ retByteArray = receiveSingleTransmission (sourceInfo)
+
+ println (System.currentTimeMillis + ": " + "I got this from receiveSingleTransmission: " + retByteArray)
+
+ // TODO: Update sourceInfo to add error notifications for the Master
+ if (retByteArray == null) { sourceInfo.receptionFailed = true }
+
+ // TODO: Supposed to update values here, but we don't support advanced
+ // statistics right now. Master can handle leecherCount by itself.
+
+ // Send back statistics to the Master
+ oosMaster.writeObject (sourceInfo)
+
+ oisMaster.close
+ oosMaster.close
+ clientSocketToMaster.close
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && retByteArray == null)
+
+ return retByteArray
+ }
+
+ // Tries to receive the broadcast from the selected source and returns the
+ // received byte array (null on failure).
+ // This might be called multiple times to retry a defined number of times.
+ private def receiveSingleTransmission(sourceInfo: SourceInfo): Array[Byte] = {
+ var clientSocketToSource: Socket = null
+ var oisSource: ObjectInputStream = null
+ var oosSource: ObjectOutputStream = null
+
+ var retByteArray:Array[Byte] = null
+
+ try {
+ // Connect to the source to get the object itself
+ clientSocketToSource =
+ new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream (clientSocketToSource.getOutputStream)
+ oisSource =
+ new ObjectInputStream (clientSocketToSource.getInputStream)
+
+ println (System.currentTimeMillis + ": " + "Inside receiveSingleTransmission")
+ println (System.currentTimeMillis + ": " + "totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
+ retByteArray = new Array[Byte] (totalBytes)
+ for (i <- 0 until totalBlocks) {
+ val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+ System.arraycopy (bcBlock.byteArray, 0, retByteArray,
+ i * BroadcastCS.blockSize, bcBlock.byteArray.length)
+ arrayOfBlocks(hasBlocks) = bcBlock
+ hasBlocks += 1
+ hasBlocksLock.synchronized {
+ hasBlocksLock.notifyAll
+ }
+ println (System.currentTimeMillis + ": " + "Received block: " + i + " " + bcBlock)
+ }
+ assert (hasBlocks == totalBlocks)
+ println (System.currentTimeMillis + ": " + "After the receive loop")
+ } catch {
+ case e: Exception => {
+ retByteArray = null
+ println (System.currentTimeMillis + ": " + "receiveSingleTransmission had a " + e)
+ }
+ } finally {
+ if (oisSource != null) { oisSource.close }
+ if (oosSource != null) {
+ oosSource.close
+ }
+ if (clientSocketToSource != null) { clientSocketToSource.close }
+ }
+
+ return retByteArray
+ }
+
+ class TrackMultipleValues extends Thread {
+ override def run = {
+ var threadPool = Executors.newCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket (BroadcastCS.masterListenPort)
+ println (System.currentTimeMillis + ": " + "TrackMultipleValues" + serverSocket + " " + listenPort)
+
+ var keepAccepting = true
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout (serverSocketTimout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ println ("TrackMultipleValues Timeout. Stopping listening...")
+ keepAccepting = false
+ }
+ }
+ println (System.currentTimeMillis + ": " + "TrackMultipleValues:Got new request:" + clientSocket)
+ if (clientSocket != null) {
+ try {
+ threadPool.execute (new Runnable {
+ def run = {
+ val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+ val ois = new ObjectInputStream (clientSocket.getInputStream)
+ try {
+ val variableUUID = ois.readObject.asInstanceOf[UUID]
+ var contactPort = 0
+ // TODO: Add logic and data structures to find out UUID->port
+ // mapping. 0 = missed the broadcast, read from HDFS; <0 =
+ // Haven't started yet, wait & retry; >0 = Read from this port
+ oos.writeObject (contactPort)
+ } catch {
+ case e: Exception => { }
+ } finally {
+ ois.close
+ oos.close
+ clientSocket.close
+ }
+ }
+ })
+ } catch {
+ // On failure, close the socket here; otherwise the thread will close it
+ case ioe: IOException => clientSocket.close
+ }
+ }
+ }
+ } finally {
+ serverSocket.close
+ }
+ }
+ }
+
+ class TrackSingleValue {
+
+ }
+
+ class GuideMultipleRequests extends Thread {
+ override def run = {
+ var threadPool = Executors.newCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket (BroadcastCS.masterListenPort)
+ // listenPort = BroadcastCS.masterListenPort
+ println (System.currentTimeMillis + ": " + "GuideMultipleRequests" + serverSocket + " " + listenPort)
+
+ var keepAccepting = true
+ try {
+ while (keepAccepting) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout (serverSocketTimout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ println ("GuideMultipleRequests Timeout. Stopping listening...")
+ keepAccepting = false
+ }
+ }
+ if (clientSocket != null) {
+ println (System.currentTimeMillis + ": " + "Guide:Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute (new GuideSingleRequest (clientSocket))
+ } catch {
+ // On failure, close the socket here; otherwise the thread will close it
+ case ioe: IOException => clientSocket.close
+ }
+ }
+ }
+ } finally {
+ serverSocket.close
+ }
+ }
+
+ class GuideSingleRequest (val clientSocket: Socket) extends Runnable {
+ private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+ private val ois = new ObjectInputStream (clientSocket.getInputStream)
+
+ private var selectedSourceInfo: SourceInfo = null
+ private var thisWorkerInfo:SourceInfo = null
+
+ def run = {
+ try {
+ println (System.currentTimeMillis + ": " + "new GuideSingleRequest is running")
+ // The connecting worker sends the hostAddress and listenPort it will be
+ // listening on. ReplicaID is 0 and the other fields are invalid (-1)
+ var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ // Select a suitable source and send it back to the worker
+ selectedSourceInfo = selectSuitableSource (sourceInfo)
+ println (System.currentTimeMillis + ": " + "Sending selectedSourceInfo:" + selectedSourceInfo)
+ oos.writeObject (selectedSourceInfo)
+ oos.flush
+
+ // Add this new (if it can finish) source to the PQ of sources
+ thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress,
+ sourceInfo.listenPort, totalBlocks, totalBytes, 0)
+ println (System.currentTimeMillis + ": " + "Adding possible new source to pqOfSources: " + thisWorkerInfo)
+ pqOfSources.synchronized {
+ pqOfSources.add (thisWorkerInfo)
+ }
+
+ // Wait till the whole transfer is done. Then receive and update source
+ // statistics in pqOfSources
+ sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ pqOfSources.synchronized {
+ // This should work since SourceInfo is a case class
+ assert (pqOfSources.contains (selectedSourceInfo))
+
+ // Remove first
+ pqOfSources.remove (selectedSourceInfo)
+ // TODO: Removing a source based on just one failure notification!
+ // Update leecher count and put it back in IF reception succeeded
+ if (!sourceInfo.receptionFailed) {
+ selectedSourceInfo.currentLeechers -= 1
+ pqOfSources.add (selectedSourceInfo)
+
+ // No need to find and update thisWorkerInfo, but add its replica
+ if (BroadcastCS.dualMode) {
+ pqOfSources.add (new SourceInfo (thisWorkerInfo.hostAddress,
+ thisWorkerInfo.listenPort, totalBlocks, totalBytes, 1))
+ }
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ case e: Exception => {
+ // Assuming the exception was caused by a failure of the receiving worker:
+ // remove the failed worker from pqOfSources and update the leecherCount of
+ // the corresponding source worker
+ pqOfSources.synchronized {
+ if (selectedSourceInfo != null) {
+ // Remove first
+ pqOfSources.remove (selectedSourceInfo)
+ // Update leecher count and put it back in
+ selectedSourceInfo.currentLeechers -= 1
+ pqOfSources.add (selectedSourceInfo)
+ }
+
+ // Remove thisWorkerInfo
+ if (pqOfSources != null) { pqOfSources.remove (thisWorkerInfo) }
+ }
+ }
+ } finally {
+ ois.close
+ oos.close
+ clientSocket.close
+ }
+ }
+
+ // TODO: If a worker fails to get the broadcasted variable from a source and
+ // comes back to the Master, this function might choose the worker itself as
+ // a source, creating a dependency cycle (this worker was put into pqOfSources
+ // as a streaming source when it first arrived). The length of this cycle can
+ // be arbitrarily long.
+ private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
+ // Select one with the lowest number of leechers
+ pqOfSources.synchronized {
+ // poll removes the head of the PQ, i.e., the source with the fewest leechers
+ var selectedSource = pqOfSources.poll
+ assert (selectedSource != null)
+ // Update leecher count
+ selectedSource.currentLeechers += 1
+ // Add it back and then return
+ pqOfSources.add (selectedSource)
+ return selectedSource
+ }
+ }
+ }
+ }
+
+ class ServeMultipleRequests extends Thread {
+ override def run = {
+ var threadPool = Executors.newCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket (0)
+ listenPort = serverSocket.getLocalPort
+ println (System.currentTimeMillis + ": " + "ServeMultipleRequests" + serverSocket + " " + listenPort)
+
+ listenPortLock.synchronized {
+ listenPortLock.notifyAll
+ }
+
+ var keepAccepting = true
+ try {
+ while (keepAccepting) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout (serverSocketTimout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ println ("ServeMultipleRequests Timeout. Stopping listening...")
+ keepAccepting = false
+ }
+ }
+ if (clientSocket != null) {
+ println (System.currentTimeMillis + ": " + "Serve:Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute (new ServeSingleRequest (clientSocket))
+ } catch {
+ // On failure, close the socket here; otherwise the thread will close it
+ case ioe: IOException => clientSocket.close
+ }
+ }
+ }
+ } finally {
+ serverSocket.close
+ }
+ }
+
+ class ServeSingleRequest (val clientSocket: Socket) extends Runnable {
+ private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+ private val ois = new ObjectInputStream (clientSocket.getInputStream)
+
+ def run = {
+ try {
+ println (System.currentTimeMillis + ": " + "new ServeSingleRequest is running")
+ sendObject
+ } catch {
+ // TODO: Need to add better exception handling here
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ case e: Exception => {
+ println (System.currentTimeMillis + ": " + "ServeSingleRequest had a " + e)
+ }
+ } finally {
+ println (System.currentTimeMillis + ": " + "ServeSingleRequest is closing streams and sockets")
+ ois.close
+ oos.close
+ clientSocket.close
+ }
+ }
+
+ private def sendObject = {
+ // Wait till receiving the SourceInfo from Master
+ while (totalBlocks == -1) {
+ totalBlocksLock.synchronized {
+ totalBlocksLock.wait
+ }
+ }
+
+ for (i <- 0 until totalBlocks) {
+ while (i == hasBlocks) {
+ hasBlocksLock.synchronized {
+ hasBlocksLock.wait
+ }
+ }
+ try {
+ oos.writeObject (arrayOfBlocks(i))
+ oos.flush
+ } catch {
+ case e: Exception => { }
+ }
+ println (System.currentTimeMillis + ": " + "Send block: " + i + " " + arrayOfBlocks(i))
+ }
+ }
+ }
+
+ }
+}
+
+private object BroadcastCH {
+ val values = new MapMaker ().softValues ().makeMap[UUID, Any]
+
+ private var initialized = false
+
+ private var fileSystem: FileSystem = null
+ private var workDir: String = null
+ private var compress: Boolean = false
+ private var bufferSize: Int = 65536
+
+ def initialize () {
+ synchronized {
+ if (!initialized) {
+ bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val dfs = System.getProperty("spark.dfs", "file:///")
+ if (!dfs.startsWith("file://")) {
+ val conf = new Configuration()
+ conf.setInt("io.file.buffer.size", bufferSize)
+ val rep = System.getProperty("spark.dfs.replication", "3").toInt
+ conf.setInt("dfs.replication", rep)
+ fileSystem = FileSystem.get(new URI(dfs), conf)
+ }
+ workDir = System.getProperty("spark.dfs.workdir", "/tmp")
+ compress = System.getProperty("spark.compress", "false").toBoolean
+
+ initialized = true
+ }
+ }
+ }
+
+ private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid)
+
+ def openFileForReading(uuid: UUID): InputStream = {
+ val fileStream = if (fileSystem != null) {
+ fileSystem.open(getPath(uuid))
+ } else {
+ // Local filesystem
+ new FileInputStream(getPath(uuid).toString)
+ }
+ if (compress)
+ new LZFInputStream(fileStream) // LZF stream does its own buffering
+ else if (fileSystem == null)
+ new BufferedInputStream(fileStream, bufferSize)
+ else
+ fileStream // Hadoop streams do their own buffering
+ }
+
+ def openFileForWriting(uuid: UUID): OutputStream = {
+ val fileStream = if (fileSystem != null) {
+ fileSystem.create(getPath(uuid))
+ } else {
+ // Local filesystem
+ new FileOutputStream(getPath(uuid).toString)
+ }
+ if (compress)
+ new LZFOutputStream(fileStream) // LZF stream does its own buffering
+ else if (fileSystem == null)
+ new BufferedOutputStream(fileStream, bufferSize)
+ else
+ fileStream // Hadoop streams do their own buffering
+ }
+}
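ChainedStreamingBroadcast works by serializing the value, cutting the resulting byte stream into blockSize chunks (blockifyObject), and letting workers stream those chunks to each other while the Guide/Serve threads track who holds what. A standalone sketch of the chunking step alone, with local names, to make the arithmetic concrete (only the splitting rule mirrors the implementation above):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectOutputStream}

object BlockifySketch {
  def main(args: Array[String]) {
    val blockSize = 512 * 1024

    // Serialize some payload to a byte array, as blockifyObject does
    val baos = new ByteArrayOutputStream
    val oos = new ObjectOutputStream(baos)
    oos.writeObject(Array.range(0, 1000000))
    oos.close
    val bytes = baos.toByteArray

    // Same rounding-up rule as blockifyObject; note that a zero-length
    // payload would yield zero blocks (the TODO in the code above)
    var blockNum = bytes.length / blockSize
    if (bytes.length % blockSize != 0) blockNum += 1

    val bais = new ByteArrayInputStream(bytes)
    val blocks = new Array[Array[Byte]](blockNum)
    var blockID = 0
    for (i <- 0 until (bytes.length, blockSize)) {
      val thisBlockSize = math.min(blockSize, bytes.length - i)
      val chunk = new Array[Byte](thisBlockSize)
      bais.read(chunk, 0, thisBlockSize)
      blocks(blockID) = chunk
      blockID += 1
    }
    bais.close

    println(bytes.length + " bytes split into " + blockNum + " blocks")
  }
}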
diff --git a/src/scala/spark/Cached.scala b/src/scala/spark/Cached.scala
deleted file mode 100644
index 8113340e1f..0000000000
--- a/src/scala/spark/Cached.scala
+++ /dev/null
@@ -1,110 +0,0 @@
-package spark
-
-import java.io._
-import java.net.URI
-import java.util.UUID
-
-import com.google.common.collect.MapMaker
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
-
-import spark.compress.lzf.{LZFInputStream, LZFOutputStream}
-
-@serializable class Cached[T](@transient var value_ : T, local: Boolean) {
- val uuid = UUID.randomUUID()
- def value = value_
-
- Cache.synchronized { Cache.values.put(uuid, value_) }
-
- if (!local) writeCacheFile()
-
- private def writeCacheFile() {
- val out = new ObjectOutputStream(Cache.openFileForWriting(uuid))
- out.writeObject(value_)
- out.close()
- }
-
- // Called by Java when deserializing an object
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject
- Cache.synchronized {
- val cachedVal = Cache.values.get(uuid)
- if (cachedVal != null) {
- value_ = cachedVal.asInstanceOf[T]
- } else {
- val start = System.nanoTime
- val fileIn = new ObjectInputStream(Cache.openFileForReading(uuid))
- value_ = fileIn.readObject().asInstanceOf[T]
- Cache.values.put(uuid, value_)
- fileIn.close()
- val time = (System.nanoTime - start) / 1e9
- println("Reading cached variable " + uuid + " took " + time + " s")
- }
- }
- }
-
- override def toString = "spark.Cached(" + uuid + ")"
-}
-
-private object Cache {
- val values = new MapMaker().softValues().makeMap[UUID, Any]()
-
- private var initialized = false
- private var fileSystem: FileSystem = null
- private var workDir: String = null
- private var compress: Boolean = false
- private var bufferSize: Int = 65536
-
- // Will be called by SparkContext or Executor before using cache
- def initialize() {
- synchronized {
- if (!initialized) {
- bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- val dfs = System.getProperty("spark.dfs", "file:///")
- if (!dfs.startsWith("file://")) {
- val conf = new Configuration()
- conf.setInt("io.file.buffer.size", bufferSize)
- val rep = System.getProperty("spark.dfs.replication", "3").toInt
- conf.setInt("dfs.replication", rep)
- fileSystem = FileSystem.get(new URI(dfs), conf)
- }
- workDir = System.getProperty("spark.dfs.workdir", "/tmp")
- compress = System.getProperty("spark.compress", "false").toBoolean
- initialized = true
- }
- }
- }
-
- private def getPath(uuid: UUID) = new Path(workDir + "/cache-" + uuid)
-
- def openFileForReading(uuid: UUID): InputStream = {
- val fileStream = if (fileSystem != null) {
- fileSystem.open(getPath(uuid))
- } else {
- // Local filesystem
- new FileInputStream(getPath(uuid).toString)
- }
- if (compress)
- new LZFInputStream(fileStream) // LZF stream does its own buffering
- else if (fileSystem == null)
- new BufferedInputStream(fileStream, bufferSize)
- else
- fileStream // Hadoop streams do their own buffering
- }
-
- def openFileForWriting(uuid: UUID): OutputStream = {
- val fileStream = if (fileSystem != null) {
- fileSystem.create(getPath(uuid))
- } else {
- // Local filesystem
- new FileOutputStream(getPath(uuid).toString)
- }
- if (compress)
- new LZFOutputStream(fileStream) // LZF stream does its own buffering
- else if (fileSystem == null)
- new BufferedOutputStream(fileStream, bufferSize)
- else
- fileStream // Hadoop streams do their own buffering
- }
-}
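Cached.scala is deleted because CentralizedHDFSBroadcast above reproduces its HDFS-file behaviour and ChainedStreamingBroadcast supersedes it as the default. From user code the migration is mechanical: the factory call changes, the .value accessor stays the same. A hedged sketch (the data and names are illustrative):

import spark.SparkContext

object MigrationSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext(args(0), "Migration Sketch")
    // Previously: val lookup = sc.cache(Array("a", "b", "c"))
    val lookup = sc.broadcast(Array("a", "b", "c"))
    // Reading is unchanged: .value on the workers
    val out = sc.parallelize(0 to 2, 2).map(i => lookup.value(i)).toArray
    println(out.mkString(","))
  }
}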
diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala
index 4cc8f00aa9..d115c6acd9 100644
--- a/src/scala/spark/Executor.scala
+++ b/src/scala/spark/Executor.scala
@@ -18,8 +18,8 @@ object Executor {
for ((key, value) <- props)
System.setProperty(key, value)
- // Initialize cache (uses some properties read above)
- Cache.initialize()
+ // Initialize broadcast system (uses some properties read above)
+ Broadcast.initialize(false)
// If the REPL is in use, create a ClassLoader that will be able to
// read new classes defined by the REPL as the user types code
diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala
index 8050683f99..87d8e8cc81 100644
--- a/src/scala/spark/HdfsFile.scala
+++ b/src/scala/spark/HdfsFile.scala
@@ -27,9 +27,9 @@ import org.apache.hadoop.mapred.Reporter
abstract class DistributedFile[T, Split](@transient sc: SparkContext) {
def splits: Array[Split]
def iterator(split: Split): Iterator[T]
- def prefers(split: Split, slot: SlaveOffer): Boolean
+ def preferredLocations(split: Split): Seq[String]
- def taskStarted(split: Split, slot: SlaveOffer) {}
+ def taskStarted(split: Split, offer: SlaveOffer) {}
def sparkContext = sc
@@ -87,8 +87,8 @@ abstract class DistributedFile[T, Split](@transient sc: SparkContext) {
abstract class FileTask[U, T, Split](val file: DistributedFile[T, Split],
val split: Split)
extends Task[U] {
- override def prefers(slot: SlaveOffer) = file.prefers(split, slot)
- override def markStarted(slot: SlaveOffer) { file.taskStarted(split, slot) }
+ override def preferredLocations: Seq[String] = file.preferredLocations(split)
+ override def markStarted(offer: SlaveOffer) { file.taskStarted(split, offer) }
}
class ForeachTask[T, Split](file: DistributedFile[T, Split],
@@ -124,31 +124,31 @@ extends FileTask[Option[T], T, Split](file, split) {
class MappedFile[U, T, Split](prev: DistributedFile[T, Split], f: T => U)
extends DistributedFile[U, Split](prev.sparkContext) {
override def splits = prev.splits
- override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
+ override def preferredLocations(sp: Split) = prev.preferredLocations(sp)
override def iterator(split: Split) = prev.iterator(split).map(f)
- override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+ override def taskStarted(split: Split, offer: SlaveOffer) = prev.taskStarted(split, offer)
}
class FilteredFile[T, Split](prev: DistributedFile[T, Split], f: T => Boolean)
extends DistributedFile[T, Split](prev.sparkContext) {
override def splits = prev.splits
- override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
+ override def preferredLocations(sp: Split) = prev.preferredLocations(sp)
override def iterator(split: Split) = prev.iterator(split).filter(f)
- override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+ override def taskStarted(split: Split, offer: SlaveOffer) = prev.taskStarted(split, offer)
}
class CachedFile[T, Split](prev: DistributedFile[T, Split])
extends DistributedFile[T, Split](prev.sparkContext) {
val id = CachedFile.newId()
- @transient val cacheLocs = Map[Split, List[Int]]()
+ @transient val cacheLocs = Map[Split, List[String]]()
override def splits = prev.splits
- override def prefers(split: Split, slot: SlaveOffer): Boolean = {
+ override def preferredLocations(split: Split): Seq[String] = {
if (cacheLocs.contains(split))
- cacheLocs(split).contains(slot.getSlaveId)
+ cacheLocs(split)
else
- prev.prefers(split, slot)
+ prev.preferredLocations(split)
}
override def iterator(split: Split): Iterator[T] = {
@@ -183,11 +183,11 @@ extends DistributedFile[T, Split](prev.sparkContext) {
}
}
- override def taskStarted(split: Split, slot: SlaveOffer) {
+ override def taskStarted(split: Split, offer: SlaveOffer) {
val oldList = cacheLocs.getOrElse(split, Nil)
- val slaveId = slot.getSlaveId
- if (!oldList.contains(slaveId))
- cacheLocs(split) = slaveId :: oldList
+ val host = offer.getHost
+ if (!oldList.contains(host))
+ cacheLocs(split) = host :: oldList
}
}
@@ -251,8 +251,10 @@ extends DistributedFile[String, HdfsSplit](sc) {
}
}
- override def prefers(split: HdfsSplit, slot: SlaveOffer) =
- split.value.getLocations().contains(slot.getHost)
+ override def preferredLocations(split: HdfsSplit) = {
+ // TODO: Filtering out "localhost" in case of file:// URLs
+ split.value.getLocations().filter(_ != "localhost").toArray
+ }
}
object ConfigureLock {}
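The interface change above replaces the per-offer prefers(split, slot) predicate with preferredLocations(split): a file simply reports which hosts hold a split and the scheduler does the matching. A hedged sketch of a trivial DistributedFile under the new contract, assuming splits, iterator and preferredLocations are the only abstract members a subclass must supply (other members are not shown in this hunk):

import spark.{DistributedFile, SparkContext}

// Hypothetical in-memory file: every split is claimed to live on the given hosts
class LocalLinesFile(sc: SparkContext, lines: Array[String], hosts: Seq[String])
extends DistributedFile[String, Int](sc) {
  override def splits: Array[Int] = Array.range(0, lines.length)
  override def iterator(split: Int): Iterator[String] = Iterator.single(lines(split))
  override def preferredLocations(split: Int): Seq[String] = hosts
}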
diff --git a/src/scala/spark/NexusScheduler.scala b/src/scala/spark/NexusScheduler.scala
index a96fca9350..a8a5e2947a 100644
--- a/src/scala/spark/NexusScheduler.scala
+++ b/src/scala/spark/NexusScheduler.scala
@@ -1,11 +1,11 @@
package spark
import java.io.File
-import java.util.concurrent.Semaphore
-import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus}
-import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver}
-import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap}
+import scala.collection.mutable.Map
+
+import nexus.{Scheduler => NScheduler}
+import nexus._
// The main Scheduler implementation, which talks to Nexus. Clients are expected
// to first call start(), then submit tasks through the runTasks method.
@@ -21,30 +21,26 @@ import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap}
// can be made cleaner.
private class NexusScheduler(
master: String, frameworkName: String, execArg: Array[Byte])
-extends nexus.Scheduler with spark.Scheduler
+extends NScheduler with spark.Scheduler
{
- // Semaphore used by runTasks to ensure only one thread can be in it
- val semaphore = new Semaphore(1)
+ // Lock used by runTasks to ensure only one thread can be in it
+ val runTasksMutex = new Object()
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
- // Trait representing a set of scheduler callbacks
- trait Callbacks {
- def slotOffer(s: SlaveOffer): Option[TaskDescription]
- def taskFinished(t: TaskStatus): Unit
- def error(code: Int, message: String): Unit
- }
-
// Current callback object (may be null)
- var callbacks: Callbacks = null
+ var activeOp: ParallelOperation = null
// Incrementing task ID
- var nextTaskId = 0
+ private var nextTaskId = 0
- // Maximum time to wait to run a task in a preferred location (in ms)
- val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong
+ def newTaskId(): Int = {
+ val id = nextTaskId;
+ nextTaskId += 1;
+ return id
+ }
// Driver for talking to Nexus
var driver: SchedulerDriver = null
@@ -66,125 +62,27 @@ extends nexus.Scheduler with spark.Scheduler
new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
override def runTasks[T](tasks: Array[Task[T]]): Array[T] = {
- val results = new Array[T](tasks.length)
- if (tasks.length == 0)
- return results
-
- val launched = new Array[Boolean](tasks.length)
-
- val callingThread = currentThread
-
- var errorHappened = false
- var errorCode = 0
- var errorMessage = ""
-
- // Wait for scheduler to be registered with Nexus
- waitForRegister()
-
- try {
- // Acquire the runTasks semaphore
- semaphore.acquire()
-
- val myCallbacks = new Callbacks {
- val firstTaskId = nextTaskId
- var tasksLaunched = 0
- var tasksFinished = 0
- var lastPreferredLaunchTime = System.currentTimeMillis
-
- def slotOffer(slot: SlaveOffer): Option[TaskDescription] = {
- try {
- if (tasksLaunched < tasks.length) {
- // TODO: Add a short wait if no task with location pref is found
- // TODO: Figure out why a function is needed around this to
- // avoid scala.runtime.NonLocalReturnException
- def findTask: Option[TaskDescription] = {
- var checkPrefVals: Array[Boolean] = Array(true)
- val time = System.currentTimeMillis
- if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
- checkPrefVals = Array(true, false) // Allow non-preferred tasks
- // TODO: Make desiredCpus and desiredMem configurable
- val desiredCpus = 1
- val desiredMem = 750L * 1024L * 1024L
- if (slot.getParams.get("cpus").toInt < desiredCpus ||
- slot.getParams.get("mem").toLong < desiredMem)
- return None
- for (checkPref <- checkPrefVals;
- i <- 0 until tasks.length;
- if !launched(i) && (!checkPref || tasks(i).prefers(slot)))
- {
- val taskId = nextTaskId
- nextTaskId += 1
- printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
- i, taskId, slot.getSlaveId, slot.getHost,
- if(checkPref) "preferred" else "non-preferred")
- tasks(i).markStarted(slot)
- launched(i) = true
- tasksLaunched += 1
- if (checkPref)
- lastPreferredLaunchTime = time
- val params = new StringMap
- params.set("cpus", "" + desiredCpus)
- params.set("mem", "" + desiredMem)
- val serializedTask = Utils.serialize(tasks(i))
- return Some(new TaskDescription(taskId, slot.getSlaveId,
- "task_" + taskId, params, serializedTask))
- }
- return None
- }
- return findTask
- } else {
- return None
- }
- } catch {
- case e: Exception => {
- e.printStackTrace
- System.exit(1)
- return None
- }
- }
- }
+ runTasksMutex.synchronized {
+ waitForRegister()
+ val myOp = new SimpleParallelOperation(this, tasks)
- def taskFinished(status: TaskStatus) {
- println("Finished TID " + status.getTaskId)
- // Deserialize task result
- val result = Utils.deserialize[TaskResult[T]](status.getData)
- results(status.getTaskId - firstTaskId) = result.value
- // Update accumulators
- Accumulators.add(callingThread, result.accumUpdates)
- // Stop if we've finished all the tasks
- tasksFinished += 1
- if (tasksFinished == tasks.length) {
- NexusScheduler.this.callbacks = null
- NexusScheduler.this.notifyAll()
- }
+ try {
+ this.synchronized {
+ this.activeOp = myOp
}
-
- def error(code: Int, message: String) {
- // Save the error message
- errorHappened = true
- errorCode = code
- errorMessage = message
- // Indicate to caller thread that we're done
- NexusScheduler.this.callbacks = null
- NexusScheduler.this.notifyAll()
+ driver.reviveOffers();
+ myOp.join();
+ } finally {
+ this.synchronized {
+ this.activeOp = null
}
}
- this.synchronized {
- this.callbacks = myCallbacks
- }
- driver.reviveOffers();
- this.synchronized {
- while (this.callbacks != null) this.wait()
- }
- } finally {
- semaphore.release()
+ if (myOp.errorHappened)
+ throw new SparkException(myOp.errorMessage, myOp.errorCode)
+ else
+ return myOp.results
}
-
- if (errorHappened)
- throw new SparkException(errorMessage, errorCode)
- else
- return results
}
override def registered(d: SchedulerDriver, frameworkId: Int) {
@@ -197,18 +95,19 @@ extends nexus.Scheduler with spark.Scheduler
override def waitForRegister() {
registeredLock.synchronized {
- while (!isRegistered) registeredLock.wait()
+ while (!isRegistered)
+ registeredLock.wait()
}
}
override def resourceOffer(
- d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) {
+ d: SchedulerDriver, oid: Long, offers: SlaveOfferVector) {
synchronized {
val tasks = new TaskDescriptionVector
- if (callbacks != null) {
+ if (activeOp != null) {
try {
- for (i <- 0 until slots.size.toInt) {
- callbacks.slotOffer(slots.get(i)) match {
+ for (i <- 0 until offers.size.toInt) {
+ activeOp.slaveOffer(offers.get(i)) match {
case Some(task) => tasks.add(task)
case None => {}
}
@@ -225,21 +124,21 @@ extends nexus.Scheduler with spark.Scheduler
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
synchronized {
- if (callbacks != null && status.getState == TaskState.TASK_FINISHED) {
- try {
- callbacks.taskFinished(status)
- } catch {
- case e: Exception => e.printStackTrace
+ try {
+ if (activeOp != null) {
+ activeOp.statusUpdate(status)
}
+ } catch {
+ case e: Exception => e.printStackTrace
}
}
}
override def error(d: SchedulerDriver, code: Int, message: String) {
synchronized {
- if (callbacks != null) {
+ if (activeOp != null) {
try {
- callbacks.error(code, message)
+ activeOp.error(code, message)
} catch {
case e: Exception => e.printStackTrace
}
@@ -256,3 +155,135 @@ extends nexus.Scheduler with spark.Scheduler
driver.stop()
}
}
+
+
+// Trait representing a set of scheduler callbacks
+trait ParallelOperation {
+ def slaveOffer(s: SlaveOffer): Option[TaskDescription]
+ def statusUpdate(t: TaskStatus): Unit
+ def error(code: Int, message: String): Unit
+}
+
+
+class SimpleParallelOperation[T](sched: NexusScheduler, tasks: Array[Task[T]])
+extends ParallelOperation
+{
+ // Maximum time to wait to run a task in a preferred location (in ms)
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong
+
+ val callingThread = currentThread
+ val numTasks = tasks.length
+ val results = new Array[T](numTasks)
+ val launched = new Array[Boolean](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val tidToIndex = Map[Int, Int]()
+
+ var allFinished = false
+ val joinLock = new Object()
+
+ var errorHappened = false
+ var errorCode = 0
+ var errorMessage = ""
+
+ var tasksLaunched = 0
+ var tasksFinished = 0
+ var lastPreferredLaunchTime = System.currentTimeMillis
+
+ def setAllFinished() {
+ joinLock.synchronized {
+ allFinished = true
+ joinLock.notifyAll()
+ }
+ }
+
+ def join() {
+ joinLock.synchronized {
+ while (!allFinished)
+ joinLock.wait()
+ }
+ }
+
+ def slaveOffer(offer: SlaveOffer): Option[TaskDescription] = {
+ if (tasksLaunched < numTasks) {
+ var checkPrefVals: Array[Boolean] = Array(true)
+ val time = System.currentTimeMillis
+ if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
+ checkPrefVals = Array(true, false) // Allow non-preferred tasks
+ // TODO: Make desiredCpus and desiredMem configurable
+ val desiredCpus = 1
+ val desiredMem = 750L * 1024L * 1024L
+ if (offer.getParams.get("cpus").toInt < desiredCpus ||
+ offer.getParams.get("mem").toLong < desiredMem)
+ return None
+ for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
+ if (!launched(i) && (!checkPref ||
+ tasks(i).preferredLocations.contains(offer.getHost) ||
+ tasks(i).preferredLocations.isEmpty))
+ {
+ val taskId = sched.newTaskId()
+ tidToIndex(taskId) = i
+ printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
+ i, taskId, offer.getSlaveId, offer.getHost,
+ if(checkPref) "preferred" else "non-preferred")
+ tasks(i).markStarted(offer)
+ launched(i) = true
+ tasksLaunched += 1
+ if (checkPref)
+ lastPreferredLaunchTime = time
+ val params = new StringMap
+ params.set("cpus", "" + desiredCpus)
+ params.set("mem", "" + desiredMem)
+ val serializedTask = Utils.serialize(tasks(i))
+ return Some(new TaskDescription(taskId, offer.getSlaveId,
+ "task_" + taskId, params, serializedTask))
+ }
+ }
+ }
+ return None
+ }
+
+ def statusUpdate(status: TaskStatus) {
+ status.getState match {
+ case TaskState.TASK_FINISHED =>
+ taskFinished(status)
+ case TaskState.TASK_LOST =>
+ taskLost(status)
+ case TaskState.TASK_FAILED =>
+ taskLost(status)
+ case TaskState.TASK_KILLED =>
+ taskLost(status)
+ case _ =>
+ }
+ }
+
+ def taskFinished(status: TaskStatus) {
+ val tid = status.getTaskId
+ println("Finished TID " + tid)
+ // Deserialize task result
+ val result = Utils.deserialize[TaskResult[T]](status.getData)
+ results(tidToIndex(tid)) = result.value
+ // Update accumulators
+ Accumulators.add(callingThread, result.accumUpdates)
+ // Mark finished and stop if we've finished all the tasks
+ finished(tidToIndex(tid)) = true
+ tasksFinished += 1
+ if (tasksFinished == numTasks)
+ setAllFinished()
+ }
+
+ def taskLost(status: TaskStatus) {
+ val tid = status.getTaskId
+ println("Lost TID " + tid)
+ launched(tidToIndex(tid)) = false
+ tasksLaunched -= 1
+ }
+
+ def error(code: Int, message: String) {
+ // Save the error message
+ errorHappened = true
+ errorCode = code
+ errorMessage = message
+ // Indicate to caller thread that we're done
+ setAllFinished()
+ }
+}
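SimpleParallelOperation.slaveOffer above implements a simple delay-scheduling policy: only placements matching a task's preferredLocations are accepted until LOCALITY_WAIT milliseconds have passed since the last preferred launch, after which any host will do. A standalone sketch of just that decision, separated from the Nexus types (the object name is illustrative):

object LocalityWaitSketch {
  // Same property and default as SimpleParallelOperation above
  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong

  // Which preference levels to try: preferred-only within the wait window,
  // preferred-then-any once the window has expired
  def allowedPreferenceLevels(lastPreferredLaunchTime: Long): Array[Boolean] = {
    if (System.currentTimeMillis - lastPreferredLaunchTime > LOCALITY_WAIT)
      Array(true, false)
    else
      Array(true)
  }

  def main(args: Array[String]) {
    println(allowedPreferenceLevels(System.currentTimeMillis).mkString(", ")) // true
    println(allowedPreferenceLevels(0L).mkString(", "))                       // true, false
  }
}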
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 4bfbcb6f21..7972702205 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -6,7 +6,7 @@ import java.util.UUID
import scala.collection.mutable.ArrayBuffer
class SparkContext(master: String, frameworkName: String) {
- Cache.initialize()
+ Broadcast.initialize(true)
def parallelize[T](seq: Seq[T], numSlices: Int): ParallelArray[T] =
new SimpleParallelArray[T](this, seq, numSlices)
@@ -17,7 +17,8 @@ class SparkContext(master: String, frameworkName: String) {
new Accumulator(initialValue, param)
// TODO: Keep around a weak hash map of values to Cached versions?
- def broadcast[T](value: T) = new Cached(value, local)
+ def broadcast[T](value: T) = new ChainedStreamingBroadcast (value, local)
+ // def broadcast[T](value: T) = new CentralizedHDFSBroadcast (value, local)
def textFile(path: String) = new HdfsTextFile(this, path)
diff --git a/src/scala/spark/Task.scala b/src/scala/spark/Task.scala
index e559996a37..efb864472d 100644
--- a/src/scala/spark/Task.scala
+++ b/src/scala/spark/Task.scala
@@ -5,8 +5,8 @@ import nexus._
@serializable
trait Task[T] {
def run: T
- def prefers(slot: SlaveOffer): Boolean = true
- def markStarted(slot: SlaveOffer) {}
+ def preferredLocations: Seq[String] = Nil
+ def markStarted(offer: SlaveOffer) {}
}
@serializable
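With the Task trait changed above, a task expresses locality by returning a host list from preferredLocations (Nil, the default, means it can run anywhere) instead of answering prefers(slot) per offer. An illustrative implementation, assuming run is the only member that must be supplied (the class and fields are hypothetical):

import spark.Task

@serializable
class CountWordsTask(lines: Seq[String], hosts: Seq[String]) extends Task[Int] {
  // Count words across all lines when the task runs on a worker
  override def run: Int = lines.map(_.split(" ").length).sum
  // Prefer the hosts passed in at construction time
  override def preferredLocations: Seq[String] = hosts
}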
diff --git a/src/test/spark/repl/ReplSuite.scala b/src/test/spark/repl/ReplSuite.scala
index d71fe20a94..43ef296efe 100644
--- a/src/test/spark/repl/ReplSuite.scala
+++ b/src/test/spark/repl/ReplSuite.scala
@@ -85,15 +85,15 @@ class ReplSuite extends FunSuite {
assertContains("res2: Int = 100", output)
}
- test ("cached vars") {
- // Test that the value that a cached var had when it was created is used,
- // even if that cached var is then modified in the driver program
+ test ("broadcast vars") {
+ // Test that the value that a broadcast var had when it was created is used,
+ // even if that broadcast var is then modified in the driver program
val output = runInterpreter("local", """
var array = new Array[Int](5)
- val cachedArray = sc.cache(array)
- sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+ val broadcastedArray = sc.broadcast(array)
+ sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray
array(0) = 5
- sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+ sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -109,10 +109,10 @@ class ReplSuite extends FunSuite {
v = 10
sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
var array = new Array[Int](5)
- val cachedArray = sc.cache(array)
- sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+ val broadcastedArray = sc.broadcast(array)
+ sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray
array(0) = 5
- sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+ sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)