author    Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)>    2011-01-10 17:13:52 -0800
committer Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)>    2011-01-10 17:13:52 -0800
commit    fd3fd3738311e022afda305183fb8417d724b5cf (patch)
tree      3b796cd9dea775ab46d4800fe930385faeb7d5cc
parent    d7081a927f1f9aa321efa6dbd74bd85fd78f0447 (diff)
download  spark-fd3fd3738311e022afda305183fb8417d724b5cf.tar.gz
          spark-fd3fd3738311e022afda305183fb8417d724b5cf.tar.bz2
          spark-fd3fd3738311e022afda305183fb8417d724b5cf.zip
In-memory version of tracker+blocked shuffle checked in.
-rw-r--r--  conf/java-opts                                              |   4
-rw-r--r--  src/scala/spark/CustomBlockedInMemoryShuffle.scala          |  27
-rw-r--r--  src/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala   | 927
-rw-r--r--  src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala  |   1
4 files changed, 933 insertions(+), 26 deletions(-)
diff --git a/conf/java-opts b/conf/java-opts
index bcce9f97e3..20eb4511f4 100644
--- a/conf/java-opts
+++ b/conf/java-opts
@@ -1,7 +1,7 @@
--Dspark.shuffle.class=spark.CustomParallelLocalFileShuffle
+-Dspark.shuffle.class=spark.TrackedCustomBlockedInMemoryShuffle
-Dspark.shuffle.masterHostAddress=127.0.0.1
-Dspark.shuffle.masterTrackerPort=22222
--Dspark.shuffle.trackerStrategy=spark.BalanceRemainingShuffleTrackerStrategy
+-Dspark.shuffle.trackerStrategy=spark.BalanceConnectionsShuffleTrackerStrategy
-Dspark.shuffle.maxRxConnections=40
-Dspark.shuffle.maxTxConnections=120
-Dspark.shuffle.blockSize=4096
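The options above are plain JVM system properties; the classes they name are resolved reflectively at runtime. A minimal sketch of that pattern, mirroring the Class.forName(...).newInstance() call the ShuffleTracker (below) uses for the strategy — the helper name is illustrative, while the property name and default come from the code:

    // Resolve the tracker strategy named in -Dspark.shuffle.trackerStrategy
    def loadTrackerStrategy(): ShuffleTrackerStrategy = {
      val strategyClass = System.getProperty(
        "spark.shuffle.trackerStrategy",
        "spark.BalanceConnectionsShuffleTrackerStrategy")
      Class.forName(strategyClass).newInstance()
        .asInstanceOf[ShuffleTrackerStrategy]
    }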
diff --git a/src/scala/spark/CustomBlockedInMemoryShuffle.scala b/src/scala/spark/CustomBlockedInMemoryShuffle.scala
index aae940dc8e..898ebe5af1 100644
--- a/src/scala/spark/CustomBlockedInMemoryShuffle.scala
+++ b/src/scala/spark/CustomBlockedInMemoryShuffle.scala
@@ -83,7 +83,7 @@ extends Shuffle[K, V, C] with Logging {
var writeStartTime: Long = 0
buckets(i).foreach(pair => {
- // Open a new file if necessary
+ // Open a new stream if necessary
if (!isDirty) {
splitName = CustomBlockedInMemoryShuffle.getSplitName(shuffleId,
myIndex, i, blockNum)
@@ -98,7 +98,7 @@ extends Shuffle[K, V, C] with Logging {
oos.writeObject(pair)
isDirty = true
- // Close the old file if has crossed the blockSize limit
+ // Close the old stream if it has crossed the blockSize limit
if (baos.size > Shuffle.BlockSize) {
CustomBlockedInMemoryShuffle.splitsCache(splitName) =
baos.toByteArray
@@ -477,30 +477,9 @@ object CustomBlockedInMemoryShuffle extends Logging {
nextShuffleId.getAndIncrement()
}
- // Returns a standard ThreadFactory except all threads are daemons
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread(r)
- t.setDaemon(true)
- return t
- }
- }
- }
-
- // Wrapper over newFixedThreadPool
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool =
- Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
-
class ShuffleServer
extends Thread with Logging {
- var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+ var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
var serverSocket: ServerSocket = null
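The two helpers deleted above were evidently promoted to the shared Shuffle object, since the call site now reads Shuffle.newDaemonFixedThreadPool. A sketch reconstructed from the removed lines, assuming the shared version is unchanged:

    import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

    // A standard ThreadFactory, except every thread is a daemon, so that
    // shuffle pools never keep the JVM alive on shutdown
    private def newDaemonThreadFactory: ThreadFactory = new ThreadFactory {
      def newThread(r: Runnable): Thread = {
        val t = Executors.defaultThreadFactory.newThread(r)
        t.setDaemon(true)
        t
      }
    }

    // Wrapper over newFixedThreadPool that installs the daemon factory
    def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
      val threadPool =
        Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
      threadPool.setThreadFactory(newDaemonThreadFactory)
      threadPool
    }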
diff --git a/src/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala b/src/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala
new file mode 100644
index 0000000000..7c7281c9d7
--- /dev/null
+++ b/src/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala
@@ -0,0 +1,927 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using in-memory data served through a custom
+ * server, where receivers create simultaneous connections to multiple
+ * servers by setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * The 'spark.shuffle.blockSize' config option controls the largest block
+ * size each map output is divided into. Essentially, instead of creating
+ * one large output per reducer, each map creates multiple smaller blocks,
+ * enabling finer-grained fetching.
+ *
+ * 'spark.shuffle.maxTxConnections' enforces the server-side connection cap.
+ * Ideally, maxTxConnections >= maxRxConnections * numReducersPerMachine.
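+ * (For example, conf/java-opts in this commit sets maxRxConnections = 40
+ * and maxTxConnections = 120, i.e., headroom for 3 reducers per machine.)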
+ *
+ * 'spark.shuffle.trackerStrategy' decides which strategy the tracker uses.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class TrackedCustomBlockedInMemoryShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+
+ @transient var totalBlocksInSplit: Array[Int] = null
+ @transient var hasBlocksInSplit: Array[Int] = null
+
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = TrackedCustomBlockedInMemoryShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ // Keep track of number of blocks for each output split
+ var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0)
+
+ for (i <- 0 until numOutputSplits) {
+ var blockNum = 0
+ var isDirty = false
+
+ var splitName = ""
+ var baos: ByteArrayOutputStream = null
+ var oos: ObjectOutputStream = null
+
+ var writeStartTime: Long = 0
+
+ buckets(i).foreach(pair => {
+ // Open a new stream if necessary
+ if (!isDirty) {
+ splitName = TrackedCustomBlockedInMemoryShuffle.getSplitName(shuffleId,
+ myIndex, i, blockNum)
+
+ baos = new ByteArrayOutputStream
+ oos = new ObjectOutputStream(baos)
+
+ writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + splitName)
+ }
+
+ oos.writeObject(pair)
+ isDirty = true
+
+ // Close the old stream if it has crossed the blockSize limit
+ if (baos.size > Shuffle.BlockSize) {
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) =
+ baos.toByteArray
+
+ logInfo("END WRITE: " + splitName)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ isDirty = false
+ oos.close()
+ }
+ })
+
+ if (isDirty) {
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+
+ logInfo("END WRITE: " + splitName)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ oos.close()
+ }
+
+ // Store BLOCKNUM info
+ splitName = TrackedCustomBlockedInMemoryShuffle.getBlockNumOutputName(
+ shuffleId, myIndex, i)
+ baos = new ByteArrayOutputStream
+ oos = new ObjectOutputStream(baos)
+ oos.writeObject(blockNum)
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+
+ // Close streams
+ oos.close()
+
+ // Store number of blocks for this outputSplit
+ numBlocksPerOutputSplit(i) = blockNum
+ }
+
+ var retVal = SplitInfo(TrackedCustomBlockedInMemoryShuffle.serverAddress,
+ TrackedCustomBlockedInMemoryShuffle.serverPort, myIndex)
+ retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit
+
+ (retVal)
+ }).collect()
+
+ // Start tracker
+ var shuffleTracker = new ShuffleTracker(outputLocs)
+ shuffleTracker.setDaemon(true)
+ shuffleTracker.start()
+ logInfo("ShuffleTracker started...")
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+
+ totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
+ hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
+
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ // Local status of hasSplitsBitVector and splitsInRequestBitVector
+ val localSplitInfo = getLocalSplitInfo(myId)
+
+ // DO NOT talk to the tracker if all the required splits are already busy
+ val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality
+
+ var numThreadsToCreate =
+ Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
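+ // (i.e., never exceed MaxRxConnections concurrent fetches, counting those
+ // already active, nor ask for more than the splits still outstanding)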
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Receive which split to pull from the tracker
+ logInfo("Talking to tracker...")
+ val splitIndex = getTrackerSelectedSplit(myId)
+ logInfo("Got %d from tracker...".format(splitIndex))
+
+ if (splitIndex != -1) {
+ val selectedSplitInfo = outputLocs(splitIndex)
+ val requestSplit =
+ "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
+
+ threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
+ requestSplit, myId))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ } else {
+ // Tracker replied with a NO. Sleep for a while.
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+ // Don't return until consumption is finished
+ // TODO: Replace with a lock later.
+ while (receivedData.size > 0) {
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ combiners
+ })
+ }
+
+ private def getLocalSplitInfo(myId: Int): SplitInfo = {
+ var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress,
+ SplitInfo.UnusedParam, myId)
+
+ // Store hasSplits
+ localSplitInfo.hasSplits = hasSplits
+
+ // Store hasSplitsBitVector
+ hasSplitsBitVector.synchronized {
+ localSplitInfo.hasSplitsBitVector =
+ hasSplitsBitVector.clone.asInstanceOf[BitSet]
+ }
+
+ // Copy hasBlocksInSplit into hasBlocksPerInputSplit
+ hasBlocksInSplit.synchronized {
+ localSplitInfo.hasBlocksPerInputSplit =
+ hasBlocksInSplit.clone.asInstanceOf[Array[Int]]
+ }
+
+ // Include the splitsInRequest as well
+ splitsInRequestBitVector.synchronized {
+ localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector)
+ }
+
+ return localSplitInfo
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+ synchronized {
+ for (i <- 0 until totalSplits) {
+ if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+ requiredSplits += i
+ }
+ }
+ }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(TrackedCustomBlockedInMemoryShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ // Talks to the tracker and receives an instruction
+ private def getTrackerSelectedSplit(myId: Int): Int = {
+ // Local status of hasSplitsBitVector and splitsInRequestBitVector
+ val localSplitInfo = getLocalSplitInfo(myId)
+
+ // DO NOT talk to the tracker if all the required splits are already busy
+ if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
+ return -1
+ }
+
+ val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+ Shuffle.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ val oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ var selectedSplitIndex = -1
+
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ logInfo("Waited enough for tracker response... Take random response...")
+
+ // sockets will be closed in finally
+
+ // TODO: We're selecting a split randomly here. The tracker won't know
+ // about it and may hit an assertion failure when this thread leaves
+
+ selectedSplitIndex = selectRandomSplit
+ }
+ }
+
+ var timeOutTimer = new Timer
+ // TODO: Which timeout to use?
+ timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval)
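+ // If the tracker doesn't reply within MinKnockInterval, timeOutTask above
+ // falls back to a locally chosen random split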
+
+ try {
+ // Send intention
+ oosTracker.writeObject(
+ TrackedCustomBlockedInMemoryShuffle.ReducerEntering)
+ oosTracker.flush()
+
+ // Send what this reducer has
+ oosTracker.writeObject(localSplitInfo)
+ oosTracker.flush()
+
+ // Receive reply from the tracker
+ selectedSplitIndex = oisTracker.readObject.asInstanceOf[Int]
+
+ // Turn the timer OFF
+ timeOutTimer.cancel()
+ } catch {
+ case e: Exception => {
+ logInfo("getTrackerSelectedSplit had a " + e)
+ }
+ } finally {
+ oisTracker.close()
+ oosTracker.close()
+ clientSocketToTracker.close()
+ }
+
+ return selectedSplitIndex
+ }
+
+ class ShuffleTracker(outputLocs: Array[SplitInfo])
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ // Create trackerStrategy object
+ val trackerStrategyClass = System.getProperty(
+ "spark.shuffle.trackerStrategy",
+ "spark.BalanceConnectionsShuffleTrackerStrategy")
+
+ val trackerStrategy =
+ Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy]
+
+ // Must initialize here by supplying the outputLocs param
+ // TODO: This could be avoided by directly passing it to the constructor
+ trackerStrategy.initialize(outputLocs)
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
+ logInfo("ShuffleTracker" + serverSocket)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ logInfo("ShuffleTracker had a " + e)
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run: Unit = {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ try {
+ // Receive intention
+ val reducerIntention = ois.readObject.asInstanceOf[Int]
+
+ if (reducerIntention ==
+ TrackedCustomBlockedInMemoryShuffle.ReducerEntering) {
+ // Receive what the reducer has
+ val reducerSplitInfo =
+ ois.readObject.asInstanceOf[SplitInfo]
+
+ // Select split and update stats if necessary
+ var selectedSplitIndex = -1
+ trackerStrategy.synchronized {
+ selectedSplitIndex = trackerStrategy.selectSplit(
+ reducerSplitInfo)
+ }
+
+ // Send reply back
+ oos.writeObject(selectedSplitIndex)
+ oos.flush()
+
+ // Update internal stats, only if receiver got the reply
+ trackerStrategy.synchronized {
+ trackerStrategy.AddReducerToSplit(reducerSplitInfo,
+ selectedSplitIndex)
+ }
+ }
+ else if (reducerIntention ==
+ TrackedCustomBlockedInMemoryShuffle.ReducerLeaving) {
+ val reducerSplitInfo =
+ ois.readObject.asInstanceOf[SplitInfo]
+
+ // Receive reception stats: how many blocks the reducer
+ // read in how much time and from where
+ val receptionStat =
+ ois.readObject.asInstanceOf[ReceptionStats]
+
+ // Update stats
+ trackerStrategy.synchronized {
+ trackerStrategy.deleteReducerFrom(reducerSplitInfo,
+ receptionStat)
+ }
+
+ // Send ACK
+ oos.writeObject(receptionStat.serverSplitIndex)
+ oos.flush()
+ }
+ else {
+ throw new SparkException("Undefined reducerIntention")
+ }
+ } catch {
+ // EOFException is expected to happen because receiver can
+ // break connection due to timeout and pick random instead
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleTracker had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+ // Run until all splits are here
+ while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo,
+ requestSplit: String, myId: Int)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ // Make sure that multiple messages don't go to the tracker
+ private var alreadySentLeavingNotification = false
+
+ // Keep track of bytes received and time spent
+ private var numBytesReceived = 0
+ private var totalTimeSpent = 0
+
+ override def run: Unit = {
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUp()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
+
+ try {
+ // Everything will break if BLOCKNUM is not correctly received
+ // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
+ peerSocketToSource = new Socket(
+ serversplitInfo.hostAddress, serversplitInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send path information
+ oosSource.writeObject(requestSplit)
+
+ // TODO: Can be optimized. No need to do it every time.
+ // Receive BLOCKNUM
+ totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel()
+
+ while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) {
+ // Set receptionSucceeded to false before trying for each block
+ receptionSucceeded = false
+
+ // Request specific block
+ oosSource.writeObject(hasBlocksInSplit(splitIndex))
+
+ // Good to go. First, receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+ // Create a temp variable to be used in different places
+ val requestPath = "http://%s:%d/shuffle/%s-%d".format(
+ serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit,
+ hasBlocksInSplit(splitIndex))
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: " + requestPath)
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
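+ // A socket read may return fewer bytes than requested, so loop until all
+ // requestedFileLen bytes of the block have arrived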
+ while (alreadyRead != requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
+
+ // Split has been received only if all the blocks have been received
+ if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ }
+ hasSplits += 1
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: " + requestPath)
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading " + requestPath + " took " + readTime + " millis.")
+
+ // Update stats
+ numBytesReceived = numBytesReceived + requestedFileLen
+ totalTimeSpent = totalTimeSpent + readTime.toInt
+ } else {
+ throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit)
+ }
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ cleanUp()
+ }
+ }
+
+ // Connect to the tracker and update its stats
+ private def sendLeavingNotification(): Unit = synchronized {
+ if (!alreadySentLeavingNotification) {
+ val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+ Shuffle.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ val oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ try {
+ // Send intention
+ oosTracker.writeObject(
+ TrackedCustomBlockedInMemoryShuffle.ReducerLeaving)
+ oosTracker.flush()
+
+ // Send reducerSplitInfo
+ oosTracker.writeObject(getLocalSplitInfo(myId))
+ oosTracker.flush()
+
+ // Send reception stats
+ oosTracker.writeObject(ReceptionStats(
+ numBytesReceived, totalTimeSpent, splitIndex))
+ oosTracker.flush()
+
+ // Receive ACK. No need to do anything with that
+ oisTracker.readObject.asInstanceOf[Int]
+
+ // Now update alreadySentLeavingNotification
+ alreadySentLeavingNotification = true
+ } catch {
+ case e: Exception => {
+ logInfo("sendLeavingNotification had a " + e)
+ }
+ } finally {
+ oisTracker.close()
+ oosTracker.close()
+ clientSocketToTracker.close()
+ }
+ }
+ }
+
+ private def cleanUp(): Unit = {
+ // Update tracker stats first
+ sendLeavingNotification()
+
+ // Clean up the connections to the mapper
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+
+ logInfo("Leaving client")
+ }
+ }
+}
+
+object TrackedCustomBlockedInMemoryShuffle extends Logging {
+ // Tracker communication constants
+ val ReducerEntering = 0
+ val ReducerLeaving = 1
+
+ // Cache for keeping the splits around
+ val splitsCache = new HashMap[String, Array[Byte]]
+
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getSplitName(shuffleId: Long, inputId: Int, outputId: Int,
+ blockId: Int): String = {
+ initializeIfNeeded()
+ // Adding shuffleDir is unnecessary. Added to keep the parsers working
+ return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId,
+ blockId)
+ }
+
+ def getBlockNumOutputName(shuffleId: Long, inputId: Int,
+ outputId: Int): String = {
+ initializeIfNeeded()
+ return "%s/%d/%d/%d-BLOCKNUM".format(shuffleDir, shuffleId, inputId,
+ outputId)
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
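+ // The ObjectOutputStream side is created and flushed before the
+ // ObjectInputStream: ObjectInputStream's constructor blocks until it
+ // reads the peer's serialization stream header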
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
+ override def run: Unit = {
+ try {
+ // Receive basic path information
+ var requestedSplitBase = ois.readObject.asInstanceOf[String]
+
+ logInfo("requestedSplitBase: " + requestedSplitBase)
+
+ // Read BLOCKNUM and send back the total number of blocks
+ val blockNumName = "%s/%s-BLOCKNUM".format(shuffleDir,
+ requestedSplitBase)
+
+ val blockNumIn = new ObjectInputStream(new ByteArrayInputStream(
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(blockNumName)))
+ val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
+ blockNumIn.close()
+
+ oos.writeObject(BLOCKNUM)
+ oos.flush()
+
+ val startTime = System.currentTimeMillis
+ var curTime = startTime
+ var keepSending = true
+ var numBlocksToSend = Shuffle.MaxChatBlocks
+
+ while (keepSending && numBlocksToSend > 0) {
+ // Receive specific block request
+ val blockId = ois.readObject.asInstanceOf[Int]
+
+ // Ready to send
+ var requestedSplit = shuffleDir + "/" + requestedSplitBase + "-" + blockId
+
+ // Send the length of the requestedSplit to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ var requestedSplitLen = -1
+
+ try {
+ requestedSplitLen =
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length
+ } catch {
+ case e: Exception => { }
+ }
+
+ oos.writeObject(requestedSplitLen)
+ oos.flush()
+
+ logInfo("requestedSplitLen = " + requestedSplitLen)
+
+ // Read and send the requested file
+ if (requestedSplitLen != -1) {
+ // Send
+ bos.write(TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit),
+ 0, requestedSplitLen)
+ bos.flush()
+
+ // Update loop variables
+ numBlocksToSend = numBlocksToSend - 1
+
+ curTime = System.currentTimeMillis
+ // Cut off sending only if there is someone waiting in the queue
+ // TODO: Turning the optimization OFF so that reducers go back to the
+ // tracker for advice
+ if (curTime - startTime >= Shuffle.MaxChatTime /* &&
+ threadPool.getQueue.size > 0 */) {
+ keepSending = false
+ }
+ } else {
+ // Close the connection
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ // EOFException is expected to happen because receiver can break
+ // connection as soon as it has all the blocks
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}
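For reference, the reducer-to-tracker handshake implemented above reduces to two message types over Java serialization. A condensed sketch of the ReducerEntering exchange — the method name is illustrative, while the constants, SplitInfo, and the master host/port properties are taken from the code above:

    import java.io.{ObjectInputStream, ObjectOutputStream}
    import java.net.Socket

    // Ask the tracker which split to fetch next; -1 means nothing right now
    def askTracker(mySplitInfo: SplitInfo): Int = {
      val sock = new Socket(Shuffle.MasterHostAddress, Shuffle.MasterTrackerPort)
      val oos = new ObjectOutputStream(sock.getOutputStream)
      oos.flush()
      val ois = new ObjectInputStream(sock.getInputStream)
      try {
        // Declare intention, then report what this reducer already has
        oos.writeObject(TrackedCustomBlockedInMemoryShuffle.ReducerEntering)
        oos.flush()
        oos.writeObject(mySplitInfo)
        oos.flush()
        // The tracker's strategy replies with a split index
        ois.readObject.asInstanceOf[Int]
      } finally {
        ois.close()
        oos.close()
        sock.close()
      }
    }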
diff --git a/src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala b/src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala
index a1032368dd..798aba9598 100644
--- a/src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala
+++ b/src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala
@@ -836,6 +836,7 @@ object TrackedCustomBlockedLocalFileShuffle extends Logging {
blockNumIn.close()
oos.writeObject(BLOCKNUM)
+ oos.flush()
val startTime = System.currentTimeMillis
var curTime = startTime
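This final one-line fix matters more than it looks: ObjectOutputStream buffers writes internally, so without an explicit flush() the reducer's blocking readObject call for BLOCKNUM can stall until the buffer fills or the socket closes. The same write-then-flush pairing already appears at the tracker call sites in the new file above.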