diff options
author | Mosharaf Chowdhury <mosharaf@cs.berkeley.edu> | 2011-04-27 20:53:43 -0700 |
---|---|---|
committer | Mosharaf Chowdhury <mosharaf@cs.berkeley.edu> | 2011-04-27 20:53:43 -0700 |
commit | 2742de707a7abfd76e3de20e10a0e4a974f12fd5 (patch) | |
tree | 2befc1d283f3b239ab2b033491cd79fcf900010d | |
parent | 9d78779257b156bec335af4ab2a66bb3cac30ca6 (diff) | |
download | spark-2742de707a7abfd76e3de20e10a0e4a974f12fd5.tar.gz spark-2742de707a7abfd76e3de20e10a0e4a974f12fd5.tar.bz2 spark-2742de707a7abfd76e3de20e10a0e4a974f12fd5.zip |
Removed some shuffle implementations. Remaining ones all use local files
to write map outputs.
9 files changed, 0 insertions, 5723 deletions
diff --git a/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala b/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala deleted file mode 100644 index 5aef43f302..0000000000 --- a/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala +++ /dev/null @@ -1,629 +0,0 @@ -package spark - -import java.io._ -import java.net._ -import java.util.{BitSet, Random, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -/** - * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW). - * - * An implementation of shuffle using local memory served through custom server - * where receivers create simultaneous connections to multiple servers by - * setting the 'spark.shuffle.maxRxConnections' config option. - * - * By controlling the 'spark.shuffle.blockSize' config option one can also - * control the largest block size to divide each map output into. Essentially, - * instead of creating one large output file for each reducer, maps create - * multiple smaller files to enable finer level of engagement. - * - * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, - * maxTxConnections >= maxRxConnections * numReducersPerMachine - * - * TODO: Add support for compression when spark.compress is set to true. - */ -@serializable -class CustomBlockedInMemoryShuffle[K, V, C] -extends Shuffle[K, V, C] with Logging { - @transient var totalSplits = 0 - @transient var hasSplits = 0 - - @transient var totalBlocksInSplit: Array[Int] = null - @transient var hasBlocksInSplit: Array[Int] = null - - @transient var hasSplitsBitVector: BitSet = null - @transient var splitsInRequestBitVector: BitSet = null - - @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null - @transient var combiners: HashMap[K,C] = null - - override def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] = - { - val sc = input.sparkContext - val shuffleId = CustomBlockedInMemoryShuffle.newShuffleId() - logInfo("Shuffle ID: " + shuffleId) - - val splitRdd = new NumberedSplitRDD(input) - val numInputSplits = splitRdd.splits.size - - // Run a parallel map and collect to write the intermediate data files, - // returning a list of inputSplitId -> serverUri pairs - val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { - val myIndex = pair._1 - val myIterator = pair._2 - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) - for ((k, v) <- myIterator) { - var bucketId = k.hashCode % numOutputSplits - if (bucketId < 0) { // Fix bucket ID if hash code was negative - bucketId += numOutputSplits - } - val bucket = buckets(bucketId) - bucket(k) = bucket.get(k) match { - case Some(c) => mergeValue(c, v) - case None => createCombiner(v) - } - } - - for (i <- 0 until numOutputSplits) { - var blockNum = 0 - var isDirty = false - - var splitName = "" - var baos: ByteArrayOutputStream = null - var oos: ObjectOutputStream = null - - var writeStartTime: Long = 0 - - buckets(i).foreach(pair => { - // Open a new stream if necessary - if (!isDirty) { - splitName = CustomBlockedInMemoryShuffle.getSplitName(shuffleId, - myIndex, i, blockNum) - - baos = new ByteArrayOutputStream - oos = new ObjectOutputStream(baos) - - writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + splitName) - } - - oos.writeObject(pair) - isDirty = true - - // Close the old stream if has crossed the blockSize limit - if (baos.size > Shuffle.BlockSize) { - CustomBlockedInMemoryShuffle.splitsCache(splitName) = - baos.toByteArray - - logInfo("END WRITE: " + splitName) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - isDirty = false - oos.close() - } - }) - - if (isDirty) { - CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray - - logInfo("END WRITE: " + splitName) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - oos.close() - } - - // Store BLOCKNUM info - splitName = CustomBlockedInMemoryShuffle.getBlockNumOutputName( - shuffleId, myIndex, i) - baos = new ByteArrayOutputStream - oos = new ObjectOutputStream(baos) - oos.writeObject(blockNum) - CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray - - // Close streams - oos.close() - } - - (myIndex, CustomBlockedInMemoryShuffle.serverAddress, - CustomBlockedInMemoryShuffle.serverPort) - }).collect() - - val splitsByUri = new ArrayBuffer[(String, Int, Int)] - for ((inputId, serverAddress, serverPort) <- outputLocs) { - splitsByUri += ((serverAddress, serverPort, inputId)) - } - - // TODO: Could broadcast outputLocs - - // Return an RDD that does each of the merges for a given partition - val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) - return indexes.flatMap((myId: Int) => { - totalSplits = outputLocs.size - hasSplits = 0 - - totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) - hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) - - hasSplitsBitVector = new BitSet(totalSplits) - splitsInRequestBitVector = new BitSet(totalSplits) - - receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] - combiners = new HashMap[K, C] - - var threadPool = Shuffle.newDaemonFixedThreadPool( - Shuffle.MaxRxConnections) - - while (hasSplits < totalSplits) { - var numThreadsToCreate = - Math.min(totalSplits, Shuffle.MaxRxConnections) - - threadPool.getActiveCount - - while (hasSplits < totalSplits && numThreadsToCreate > 0) { - // Select a random split to pull - val splitIndex = selectRandomSplit - - if (splitIndex != -1) { - val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex) - - threadPool.execute(new ShuffleClient(serverAddress, serverPort, - shuffleId.toInt, inputId, myId, splitIndex)) - - // splitIndex is in transit. Will be unset in the ShuffleClient - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex) - } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before creating new threads - Thread.sleep(Shuffle.MinKnockInterval) - } - - threadPool.shutdown() - - // Start consumer - // TODO: Consumption is delayed until everything has been received. - // Otherwise it interferes with network performance - var shuffleConsumer = new ShuffleConsumer(mergeCombiners) - shuffleConsumer.setDaemon(true) - shuffleConsumer.start() - logInfo("ShuffleConsumer started...") - - // Don't return until consumption is finished - // while (receivedData.size > 0) { - // Thread.sleep(Shuffle.MinKnockInterval) - // } - - // Wait till shuffleConsumer is done - shuffleConsumer.join - - combiners - }) - } - - def selectRandomSplit: Int = { - var requiredSplits = new ArrayBuffer[Int] - - synchronized { - for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { - requiredSplits += i - } - } - } - - if (requiredSplits.size > 0) { - requiredSplits(CustomBlockedInMemoryShuffle.ranGen.nextInt( - requiredSplits.size)) - } else { - -1 - } - } - - class ShuffleConsumer(mergeCombiners: (C, C) => C) - extends Thread with Logging { - override def run: Unit = { - // Run until all splits are here - while (receivedData.size > 0) { - var splitIndex = -1 - var recvByteArray: Array[Byte] = null - - try { - var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] - splitIndex = tempPair._1 - recvByteArray = tempPair._2 - } catch { - case e: Exception => { - logInfo("Exception during taking data from receivedData") - } - } - - val inputStream = - new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) - - try{ - while (true) { - val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } - } - } catch { - case e: EOFException => { } - } - inputStream.close() - } - } - } - - class ShuffleClient(hostAddress: String, listenPort: Int, shuffleId: Int, - inputId: Int, myId: Int, splitIndex: Int) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - private var receptionSucceeded = false - - override def run: Unit = { - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) - - try { - // Everything will break if BLOCKNUM is not correctly received - // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown - peerSocketToSource = new Socket(hostAddress, listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - var isSource = peerSocketToSource.getInputStream - oisSource = new ObjectInputStream(isSource) - - // Send path information - oosSource.writeObject((shuffleId, inputId, myId)) - - // TODO: Can be optimized. No need to do it everytime. - // Receive BLOCKNUM - totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int] - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) { - // Set receptionSucceeded to false before trying for each block - receptionSucceeded = false - - // Request specific block - oosSource.writeObject(hasBlocksInSplit(splitIndex)) - - // Good to go. First, receive the length of the requested file - var requestedFileLen = oisSource.readObject.asInstanceOf[Int] - logInfo("Received requestedFileLen = " + requestedFileLen) - - val requestSplit = "%d/%d/%d-%d".format(shuffleId, inputId, myId, - hasBlocksInSplit(splitIndex)) - - // Receive the file - if (requestedFileLen != -1) { - val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) - - // Receive data in an Array[Byte] - var recvByteArray = new Array[Byte](requestedFileLen) - var alreadyRead = 0 - var bytesRead = 0 - - while (alreadyRead != requestedFileLen) { - bytesRead = isSource.read(recvByteArray, alreadyRead, - requestedFileLen - alreadyRead) - if (bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - - // Make it available to the consumer - try { - receivedData.put((splitIndex, recvByteArray)) - } catch { - case e: Exception => { - logInfo("Exception during putting data into receivedData") - } - } - - // TODO: Updating stats before consumption is completed - hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 - - // Split has been received only if all the blocks have been received - if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set(splitIndex) - hasSplits += 1 - } - } - - // Consistent state in accounting variables - receptionSucceeded = true - - logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) - val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.") - } else { - throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit) - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleClient had a " + e) - } - } finally { - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - cleanUpConnections() - } - } - - private def cleanUpConnections(): Unit = { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - } - } -} - -object CustomBlockedInMemoryShuffle extends Logging { - // Cache for keeping the splits around - val splitsCache = new HashMap[String, Array[Byte]] - - private var initialized = false - private var nextShuffleId = new AtomicLong(0) - - // Variables initialized by initializeIfNeeded() - private var shuffleDir: File = null - - private var shuffleServer: ShuffleServer = null - private var serverAddress = InetAddress.getLocalHost.getHostAddress - private var serverPort: Int = -1 - - // Random number generator - var ranGen = new Random - - private def initializeIfNeeded() = synchronized { - if (!initialized) { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - // Create and start the shuffleServer - shuffleServer = new ShuffleServer - shuffleServer.setDaemon(true) - shuffleServer.start() - logInfo("ShuffleServer started...") - - initialized = true - } - } - - def getSplitName(shuffleId: Long, inputId: Int, outputId: Int, - blockId: Int): String = { - initializeIfNeeded() - // Adding shuffleDir is unnecessary. Added to keep the parsers working - return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId, - blockId) - } - - def getBlockNumOutputName(shuffleId: Long, inputId: Int, - outputId: Int): String = { - initializeIfNeeded() - return "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir, shuffleId, inputId, - outputId) - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } - - class ShuffleServer - extends Thread with Logging { - var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections) - - var serverSocket: ServerSocket = null - - override def run: Unit = { - serverSocket = new ServerSocket(0) - serverPort = serverSocket.getLocalPort - - logInfo("ShuffleServer started with " + serverSocket) - logInfo("Local URI: http://" + serverAddress + ":" + serverPort) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logInfo("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ShuffleServerThread(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ShuffleServer now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ShuffleServerThread(val clientSocket: Socket) - extends Thread with Logging { - private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] - os.flush() - private val bos = new BufferedOutputStream(os) - bos.flush() - private val oos = new ObjectOutputStream(os) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ShuffleServerThread is running") - - override def run: Unit = { - try { - // Receive basic path information - val (shuffleId, myIndex, outputId) = - ois.readObject.asInstanceOf[(Int, Int, Int)] - - var requestedSplitBase = "%s/%d/%d/%d".format( - shuffleDir, shuffleId, myIndex, outputId) - logInfo("requestedSplitBase: " + requestedSplitBase) - - // Read BLOCKNUM and send back the total number of blocks - val blockNumName = "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir, - shuffleId, myIndex, outputId) - - val blockNumIn = new ObjectInputStream(new ByteArrayInputStream( - CustomBlockedInMemoryShuffle.splitsCache(blockNumName))) - val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int] - blockNumIn.close() - - oos.writeObject(BLOCKNUM) - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = Shuffle.MaxChatBlocks - - while (keepSending && numBlocksToSend > 0) { - // Receive specific block request - val blockId = ois.readObject.asInstanceOf[Int] - - // Ready to send - var requestedSplit = requestedSplitBase + "-" + blockId - - // Send the length of the requestedSplit to let the receiver know that - // transfer is about to start - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - var requestedSplitLen = -1 - - try { - requestedSplitLen = - CustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length - } catch { - case e: Exception => { } - } - - oos.writeObject(requestedSplitLen) - oos.flush() - - logInfo("requestedSplitLen = " + requestedSplitLen) - - // Read and send the requested file - if (requestedSplitLen != -1) { - // Send - bos.write(CustomBlockedInMemoryShuffle.splitsCache(requestedSplit), - 0, requestedSplitLen) - bos.flush() - - // Update loop variables - numBlocksToSend = numBlocksToSend - 1 - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= Shuffle.MaxChatTime && - threadPool.getQueue.size > 0) { - keepSending = false - } - } else { - // Close the connection - } - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc - // then close everything up - // Exception can happen if the receiver stops receiving - // EOFException is expected to happen because receiver can break - // connection as soon as it has all the blocks - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleServerThread had a " + e) - } - } finally { - logInfo("ShuffleServerThread is closing streams and sockets") - ois.close() - // TODO: Following can cause "java.net.SocketException: Socket closed" - oos.close() - bos.close() - clientSocket.close() - } - } - } - } -} diff --git a/core/src/main/scala/spark/CustomParallelFakeShuffle.scala b/core/src/main/scala/spark/CustomParallelFakeShuffle.scala deleted file mode 100644 index 889c5111b6..0000000000 --- a/core/src/main/scala/spark/CustomParallelFakeShuffle.scala +++ /dev/null @@ -1,495 +0,0 @@ -package spark - -import java.io._ -import java.net._ -import java.util.{BitSet, Random, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -/** - * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW). - * - * An implementation of shuffle using fake data served through custom server - * where receivers create simultaneous connections to multiple servers by - * setting the 'spark.shuffle.maxRxConnections' config option. - * - * TODO: Add support for compression when spark.compress is set to true. - */ -@serializable -class CustomParallelFakeShuffle[K, V, C] -extends Shuffle[K, V, C] with Logging { - @transient var totalSplits = 0 - @transient var hasSplits = 0 - @transient var hasSplitsBitVector: BitSet = null - @transient var splitsInRequestBitVector: BitSet = null - - @transient var combiners: HashMap[K,C] = null - - override def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] = - { - val sc = input.sparkContext - val shuffleId = CustomParallelFakeShuffle.newShuffleId() - logInfo("Shuffle ID: " + shuffleId) - - val splitRdd = new NumberedSplitRDD(input) - val numInputSplits = splitRdd.splits.size - - // Run a parallel map and collect to write the intermediate data files, - // returning a list of inputSplitId -> serverUri pairs - val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { - val myIndex = pair._1 - val myIterator = pair._2 - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) - for ((k, v) <- myIterator) { - var bucketId = k.hashCode % numOutputSplits - if (bucketId < 0) { // Fix bucket ID if hash code was negative - bucketId += numOutputSplits - } - val bucket = buckets(bucketId) - bucket(k) = bucket.get(k) match { - case Some(c) => mergeValue(c, v) - case None => createCombiner(v) - } - } - - for (i <- 0 until numOutputSplits) { - val splitName = - CustomParallelFakeShuffle.getSplitName(shuffleId, myIndex, i) - - val writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + splitName) - - // Write buckets(i) to a byte array & put in splitsCache instead of file - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - buckets(i).foreach(pair => oos.writeObject(pair)) - oos.close - baos.close - - // Store the length only - CustomParallelFakeShuffle.splitsCache(splitName) = baos.toByteArray.length - val splitLen = baos.toByteArray.length - - logInfo("END WRITE: " + splitName) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.") - } - - (myIndex, CustomParallelFakeShuffle.serverAddress, - CustomParallelFakeShuffle.serverPort) - }).collect() - - val splitsByUri = new ArrayBuffer[(String, Int, Int)] - for ((inputId, serverAddress, serverPort) <- outputLocs) { - splitsByUri += ((serverAddress, serverPort, inputId)) - } - - // TODO: Could broadcast splitsByUri - - // Return an RDD that does each of the merges for a given partition - val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) - return indexes.flatMap((myId: Int) => { - totalSplits = splitsByUri.size - hasSplits = 0 - hasSplitsBitVector = new BitSet(totalSplits) - splitsInRequestBitVector = new BitSet(totalSplits) - - combiners = new HashMap[K, C] - - var threadPool = Shuffle.newDaemonFixedThreadPool( - Shuffle.MaxRxConnections) - - var totalThreadsCreated = 0 - - while (hasSplits < totalSplits) { - var numThreadsToCreate = Math.min(totalSplits, - Shuffle.MaxRxConnections) - - threadPool.getActiveCount - - while (hasSplits < totalSplits && numThreadsToCreate > 0) { - // Select a random split to pull - val splitIndex = selectRandomSplit - - if (splitIndex != -1) { - val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex) - val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId) - - threadPool.execute(new ShuffleClient(splitIndex, serverAddress, - serverPort, requestSplit)) - totalThreadsCreated += 1 - logInfo("totalThreadsCreated = %d / threadPool.getActiveCount = %d".format(totalThreadsCreated, threadPool.getActiveCount.toInt)) - - // splitIndex is in transit. Will be unset in the ShuffleClient - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex) - } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before creating new threads - Thread.sleep(Shuffle.MinKnockInterval) - } - - threadPool.shutdown() - - combiners - }) - } - - def selectRandomSplit: Int = { - var requiredSplits = new ArrayBuffer[Int] - - synchronized { - for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { - requiredSplits += i - } - } - } - - if (requiredSplits.size > 0) { - requiredSplits(CustomParallelFakeShuffle.ranGen.nextInt( - requiredSplits.size)) - } else { - -1 - } - } - - class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int, - requestSplit: String) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - private var receptionSucceeded = false - - override def run: Unit = { - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) - - logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit)) - - try { - // Connect to the source - peerSocketToSource = new Socket(hostAddress, listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - var isSource = peerSocketToSource.getInputStream - oisSource = new ObjectInputStream(isSource) - - // Send the request - oosSource.writeObject(requestSplit) - - // Receive the length of the requested file - var requestedFileLen = oisSource.readObject.asInstanceOf[Int] - logInfo("Received requestedFileLen = " + requestedFileLen) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Receive the file - if (requestedFileLen != -1) { - val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) - -// var recvByteArray = new Array[Byte](requestedFileLen) -// var alreadyRead = 0 -// var bytesRead = 0 -// -// while (alreadyRead != requestedFileLen) { -// bytesRead = isSource.read(recvByteArray, alreadyRead, -// requestedFileLen - alreadyRead) -// if (bytesRead > 0) { -// alreadyRead = alreadyRead + bytesRead -// } -// } - - // Receive data in an Array[Byte] - var recvByteArray = new Array[Byte](65536) - var alreadyRead = 0 - var bytesRead = 0 - - while (alreadyRead != requestedFileLen) { - bytesRead = isSource.read(recvByteArray) - if (bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - - // TODO: Updating stats before consumption is completed - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set(splitIndex) - hasSplits += 1 - } - - // We have received splitIndex - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - - receptionSucceeded = true - - logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) - val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.") - } else { - throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit) - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleClient had a " + e) - } - } finally { - // If reception failed, unset for future retry - if (!receptionSucceeded) { - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - } - cleanUpConnections() - } - } - - private def cleanUpConnections(): Unit = { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - } - } -} - -object CustomParallelFakeShuffle extends Logging { - // Cache for keeping the splits around - val splitsCache = new HashMap[String, Int] - - private var initialized = false - private var nextShuffleId = new AtomicLong(0) - - // Variables initialized by initializeIfNeeded() - private var shuffleDir: File = null - - private var shuffleServer: ShuffleServer = null - private var serverAddress = InetAddress.getLocalHost.getHostAddress - private var serverPort: Int = -1 - - // Random number generator - var ranGen = new Random - - private def initializeIfNeeded() = synchronized { - if (!initialized) { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - // Create and start the shuffleServer - shuffleServer = new ShuffleServer - shuffleServer.setDaemon(true) - shuffleServer.start() - logInfo("ShuffleServer started...") - - initialized = true - } - } - - def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = { - initializeIfNeeded() - // Adding shuffleDir is unnecessary. Added to keep the parsers working - return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId) - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } - - // Returns a standard ThreadFactory except all threads are daemons - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread(r) - t.setDaemon(true) - return t - } - } - } - - // Wrapper over newFixedThreadPool - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = - Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } - - class ShuffleServer - extends Thread with Logging { - var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections) - - var serverSocket: ServerSocket = null - - override def run: Unit = { - serverSocket = new ServerSocket(0) - serverPort = serverSocket.getLocalPort - - logInfo("ShuffleServer started with " + serverSocket) - logInfo("Local URI: http://" + serverAddress + ":" + serverPort) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logInfo("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ShuffleServerThread(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ShuffleServer now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ShuffleServerThread(val clientSocket: Socket) - extends Thread with Logging { - private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] - os.flush() - private val bos = new BufferedOutputStream(os) - bos.flush() - private val oos = new ObjectOutputStream(os) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ShuffleServerThread is running") - - override def run: Unit = { - try { - // Receive requestedSplit from the receiver - // Adding shuffleDir is unnecessary. Added to keep the parsers working - var requestedSplit = - shuffleDir + "/" + ois.readObject.asInstanceOf[String] - logInfo("requestedSplit: " + requestedSplit) - - // Send the length of the requestedSplit to let the receiver know that - // transfer is about to start - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - var requestedSplitLen = -1 - - try { - requestedSplitLen = - CustomParallelFakeShuffle.splitsCache(requestedSplit) - } catch { - case e: Exception => { } - } - - oos.writeObject(requestedSplitLen) - oos.flush() - - logInfo("requestedSplitLen = " + requestedSplitLen) - - // Read and send the requested split - if (requestedSplitLen != -1) { - // Send fake data -// var byteArray = new Array[Byte](requestedSplitLen) -// bos.write(byteArray, 0, byteArray.length) -// bos.flush() - - val buf = new Array[Byte](65536) - var bytesSent = 0 - while (bytesSent < requestedSplitLen) { - val bytesToSend = Math.min(requestedSplitLen - bytesSent, buf.length) - bos.write(buf, 0, bytesToSend) - bos.flush() - bytesSent += bytesToSend - } - } else { - // Close the connection - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc - // then close everything up - // Exception can happen if the receiver stops receiving - case e: Exception => { - logInfo("ShuffleServerThread had a " + e) - } - } finally { - logInfo("ShuffleServerThread is closing streams and sockets") - ois.close() - // TODO: Following can cause "java.net.SocketException: Socket closed" - oos.close() - bos.close() - clientSocket.close() - } - } - } - } -} diff --git a/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala b/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala deleted file mode 100644 index 48f3685a1a..0000000000 --- a/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala +++ /dev/null @@ -1,535 +0,0 @@ -package spark - -import java.io._ -import java.net._ -import java.util.{BitSet, Random, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -/** - * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW). - * - * An implementation of shuffle using local memory served through custom server - * where receivers create simultaneous connections to multiple servers by - * setting the 'spark.shuffle.maxRxConnections' config option. - * - * TODO: Add support for compression when spark.compress is set to true. - */ -@serializable -class CustomParallelInMemoryShuffle[K, V, C] -extends Shuffle[K, V, C] with Logging { - @transient var totalSplits = 0 - @transient var hasSplits = 0 - @transient var hasSplitsBitVector: BitSet = null - @transient var splitsInRequestBitVector: BitSet = null - - @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null - @transient var combiners: HashMap[K,C] = null - - override def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] = - { - val sc = input.sparkContext - val shuffleId = CustomParallelInMemoryShuffle.newShuffleId() - logInfo("Shuffle ID: " + shuffleId) - - val splitRdd = new NumberedSplitRDD(input) - val numInputSplits = splitRdd.splits.size - - // Run a parallel map and collect to write the intermediate data files, - // returning a list of inputSplitId -> serverUri pairs - val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { - val myIndex = pair._1 - val myIterator = pair._2 - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) - for ((k, v) <- myIterator) { - var bucketId = k.hashCode % numOutputSplits - if (bucketId < 0) { // Fix bucket ID if hash code was negative - bucketId += numOutputSplits - } - val bucket = buckets(bucketId) - bucket(k) = bucket.get(k) match { - case Some(c) => mergeValue(c, v) - case None => createCombiner(v) - } - } - - for (i <- 0 until numOutputSplits) { - val splitName = - CustomParallelInMemoryShuffle.getSplitName(shuffleId, myIndex, i) - - val writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + splitName) - - // Write buckets(i) to a byte array & put in splitsCache instead of file - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - buckets(i).foreach(pair => oos.writeObject(pair)) - oos.close - baos.close - - CustomParallelInMemoryShuffle.splitsCache(splitName) = baos.toByteArray - val splitLen = - CustomParallelInMemoryShuffle.splitsCache(splitName).length - - logInfo("END WRITE: " + splitName) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.") - } - - (myIndex, CustomParallelInMemoryShuffle.serverAddress, - CustomParallelInMemoryShuffle.serverPort) - }).collect() - - val splitsByUri = new ArrayBuffer[(String, Int, Int)] - for ((inputId, serverAddress, serverPort) <- outputLocs) { - splitsByUri += ((serverAddress, serverPort, inputId)) - } - - // TODO: Could broadcast splitsByUri - - // Return an RDD that does each of the merges for a given partition - val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) - return indexes.flatMap((myId: Int) => { - totalSplits = splitsByUri.size - hasSplits = 0 - hasSplitsBitVector = new BitSet(totalSplits) - splitsInRequestBitVector = new BitSet(totalSplits) - - receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] - combiners = new HashMap[K, C] - - var threadPool = Shuffle.newDaemonFixedThreadPool( - Shuffle.MaxRxConnections) - - while (hasSplits < totalSplits) { - var numThreadsToCreate = Math.min(totalSplits, - Shuffle.MaxRxConnections) - - threadPool.getActiveCount - - while (hasSplits < totalSplits && numThreadsToCreate > 0) { - // Select a random split to pull - val splitIndex = selectRandomSplit - - if (splitIndex != -1) { - val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex) - val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId) - - threadPool.execute(new ShuffleClient(splitIndex, serverAddress, - serverPort, requestSplit)) - - // splitIndex is in transit. Will be unset in the ShuffleClient - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex) - } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before creating new threads - Thread.sleep(Shuffle.MinKnockInterval) - } - - threadPool.shutdown() - - // Start consumer - // TODO: Consumption is delayed until everything has been received. - // Otherwise it interferes with network performance - var shuffleConsumer = new ShuffleConsumer(mergeCombiners) - shuffleConsumer.setDaemon(true) - shuffleConsumer.start() - logInfo("ShuffleConsumer started...") - - // Don't return until consumption is finished - // while (receivedData.size > 0) { - // Thread.sleep(Shuffle.MinKnockInterval) - // } - - // Wait till shuffleConsumer is done - shuffleConsumer.join - - combiners - }) - } - - def selectRandomSplit: Int = { - var requiredSplits = new ArrayBuffer[Int] - - synchronized { - for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { - requiredSplits += i - } - } - } - - if (requiredSplits.size > 0) { - requiredSplits(CustomParallelInMemoryShuffle.ranGen.nextInt( - requiredSplits.size)) - } else { - -1 - } - } - - class ShuffleConsumer(mergeCombiners: (C, C) => C) - extends Thread with Logging { - override def run: Unit = { - // Run until all splits are here - while (receivedData.size > 0) { - var splitIndex = -1 - var recvByteArray: Array[Byte] = null - - try { - var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] - splitIndex = tempPair._1 - recvByteArray = tempPair._2 - } catch { - case e: Exception => { - logInfo("Exception during taking data from receivedData") - } - } - - val inputStream = - new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) - - try{ - while (true) { - val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } - } - } catch { - case e: EOFException => { } - } - inputStream.close() - } - } - } - - class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int, - requestSplit: String) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - private var receptionSucceeded = false - - override def run: Unit = { - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) - - logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit)) - - try { - // Connect to the source - peerSocketToSource = new Socket(hostAddress, listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - var isSource = peerSocketToSource.getInputStream - oisSource = new ObjectInputStream(isSource) - - // Send the request - oosSource.writeObject(requestSplit) - - // Receive the length of the requested file - var requestedFileLen = oisSource.readObject.asInstanceOf[Int] - logInfo("Received requestedFileLen = " + requestedFileLen) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Receive the file - if (requestedFileLen != -1) { - val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) - - // Receive data in an Array[Byte] - var recvByteArray = new Array[Byte](requestedFileLen) - var alreadyRead = 0 - var bytesRead = 0 - - while (alreadyRead != requestedFileLen) { - bytesRead = isSource.read(recvByteArray, alreadyRead, - requestedFileLen - alreadyRead) - if (bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - - // Make it available to the consumer - try { - receivedData.put((splitIndex, recvByteArray)) - } catch { - case e: Exception => { - logInfo("Exception during putting data into receivedData") - } - } - - // TODO: Updating stats before consumption is completed - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set(splitIndex) - hasSplits += 1 - } - - // We have received splitIndex - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - - receptionSucceeded = true - - logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) - val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.") - } else { - throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit) - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleClient had a " + e) - } - } finally { - // If reception failed, unset for future retry - if (!receptionSucceeded) { - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - } - cleanUpConnections() - } - } - - private def cleanUpConnections(): Unit = { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - } - } -} - -object CustomParallelInMemoryShuffle extends Logging { - // Cache for keeping the splits around - val splitsCache = new HashMap[String, Array[Byte]] - - private var initialized = false - private var nextShuffleId = new AtomicLong(0) - - // Variables initialized by initializeIfNeeded() - private var shuffleDir: File = null - - private var shuffleServer: ShuffleServer = null - private var serverAddress = InetAddress.getLocalHost.getHostAddress - private var serverPort: Int = -1 - - // Random number generator - var ranGen = new Random - - private def initializeIfNeeded() = synchronized { - if (!initialized) { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - // Create and start the shuffleServer - shuffleServer = new ShuffleServer - shuffleServer.setDaemon(true) - shuffleServer.start() - logInfo("ShuffleServer started...") - - initialized = true - } - } - - def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = { - initializeIfNeeded() - // Adding shuffleDir is unnecessary. Added to keep the parsers working - return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId) - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } - - // Returns a standard ThreadFactory except all threads are daemons - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread(r) - t.setDaemon(true) - return t - } - } - } - - // Wrapper over newFixedThreadPool - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = - Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } - - class ShuffleServer - extends Thread with Logging { - var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections) - - var serverSocket: ServerSocket = null - - override def run: Unit = { - serverSocket = new ServerSocket(0) - serverPort = serverSocket.getLocalPort - - logInfo("ShuffleServer started with " + serverSocket) - logInfo("Local URI: http://" + serverAddress + ":" + serverPort) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logInfo("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ShuffleServerThread(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ShuffleServer now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ShuffleServerThread(val clientSocket: Socket) - extends Thread with Logging { - private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] - os.flush() - private val bos = new BufferedOutputStream(os) - bos.flush() - private val oos = new ObjectOutputStream(os) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ShuffleServerThread is running") - - override def run: Unit = { - try { - // Receive requestedSplit from the receiver - // Adding shuffleDir is unnecessary. Added to keep the parsers working - var requestedSplit = - shuffleDir + "/" + ois.readObject.asInstanceOf[String] - logInfo("requestedSplit: " + requestedSplit) - - // Send the length of the requestedSplit to let the receiver know that - // transfer is about to start - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - var requestedSplitLen = -1 - - try { - requestedSplitLen = - CustomParallelInMemoryShuffle.splitsCache(requestedSplit).length - } catch { - case e: Exception => { } - } - - oos.writeObject(requestedSplitLen) - oos.flush() - - logInfo("requestedSplitLen = " + requestedSplitLen) - - // Read and send the requested split - if (requestedSplitLen != -1) { - // Send - bos.write(CustomParallelInMemoryShuffle.splitsCache(requestedSplit), - 0, requestedSplitLen) - bos.flush() - } else { - // Close the connection - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc - // then close everything up - // Exception can happen if the receiver stops receiving - case e: Exception => { - logInfo("ShuffleServerThread had a " + e) - } - } finally { - logInfo("ShuffleServerThread is closing streams and sockets") - ois.close() - // TODO: Following can cause "java.net.SocketException: Socket closed" - oos.close() - bos.close() - clientSocket.close() - } - } - } - } -} diff --git a/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala deleted file mode 100644 index 8e89cadfdd..0000000000 --- a/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala +++ /dev/null @@ -1,471 +0,0 @@ -package spark - -import java.io._ -import java.net._ -import java.util.{BitSet, Random, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -/** - * An implementation of shuffle using local files served through HTTP where - * receivers create simultaneous connections to multiple servers by setting the - * 'spark.shuffle.maxRxConnections' config option. - * - * By controlling the 'spark.shuffle.blockSize' config option one can also - * control the largest block size to retrieve by each reducers. An INDEX file - * keeps track of block boundaries instead of creating many smaller files. - * - * TODO: Add support for compression when spark.compress is set to true. - */ -@serializable -class HttpBlockedLocalFileShuffle[K, V, C] -extends Shuffle[K, V, C] with Logging { - @transient var totalSplits = 0 - @transient var hasSplits = 0 - - @transient var blocksInSplit: Array[ArrayBuffer[Long]] = null - @transient var totalBlocksInSplit: Array[Int] = null - @transient var hasBlocksInSplit: Array[Int] = null - - @transient var hasSplitsBitVector: BitSet = null - @transient var splitsInRequestBitVector: BitSet = null - - @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null - @transient var combiners: HashMap[K,C] = null - - override def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] = - { - val sc = input.sparkContext - val shuffleId = HttpBlockedLocalFileShuffle.newShuffleId() - logInfo("Shuffle ID: " + shuffleId) - - val splitRdd = new NumberedSplitRDD(input) - val numInputSplits = splitRdd.splits.size - - // Run a parallel map and collect to write the intermediate data files, - // returning a list of inputSplitId -> serverUri pairs - val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { - val myIndex = pair._1 - val myIterator = pair._2 - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) - for ((k, v) <- myIterator) { - var bucketId = k.hashCode % numOutputSplits - if (bucketId < 0) { // Fix bucket ID if hash code was negative - bucketId += numOutputSplits - } - val bucket = buckets(bucketId) - bucket(k) = bucket.get(k) match { - case Some(c) => mergeValue(c, v) - case None => createCombiner(v) - } - } - - for (i <- 0 until numOutputSplits) { - // Open the INDEX file - var indexFile: File = - HttpBlockedLocalFileShuffle.getBlockIndexOutputFile(shuffleId, - myIndex, i) - var indexOut = new ObjectOutputStream(new FileOutputStream(indexFile)) - var indexDirty: Boolean = true - var alreadyWritten: Long = 0 - - // Open the actual file - var file: File = - HttpBlockedLocalFileShuffle.getOutputFile(shuffleId, myIndex, i) - val out = new ObjectOutputStream(new FileOutputStream(file)) - - val writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + file) - - buckets(i).foreach(pair => { - out.writeObject(pair) - out.flush() - indexDirty = true - - // Update the INDEX file if more than blockSize limit has been written - if (file.length - alreadyWritten > Shuffle.BlockSize) { - indexOut.writeObject(file.length) - indexDirty = false - alreadyWritten = file.length - } - }) - - // Write down the last range if it was not written - if (indexDirty) { - indexOut.writeObject(file.length) - } - - out.close() - indexOut.close() - - logInfo("END WRITE: " + file) - val writeTime = (System.currentTimeMillis - writeStartTime) - logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") - } - - (myIndex, HttpBlockedLocalFileShuffle.serverUri) - }).collect() - - // TODO: Could broadcast outputLocs - - // Return an RDD that does each of the merges for a given partition - val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) - return indexes.flatMap((myId: Int) => { - totalSplits = outputLocs.size - hasSplits = 0 - - blocksInSplit = Array.tabulate(totalSplits)(_ => new ArrayBuffer[Long]) - totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) - hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) - - hasSplitsBitVector = new BitSet(totalSplits) - splitsInRequestBitVector = new BitSet(totalSplits) - - receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] - combiners = new HashMap[K, C] - - var threadPool = Shuffle.newDaemonFixedThreadPool( - Shuffle.MaxRxConnections) - - while (hasSplits < totalSplits) { - var numThreadsToCreate = - Math.min(totalSplits, Shuffle.MaxRxConnections) - - threadPool.getActiveCount - - while (hasSplits < totalSplits && numThreadsToCreate > 0) { - // Select a random split to pull - val splitIndex = selectRandomSplit - - if (splitIndex != -1) { - val (inputId, serverUri) = outputLocs(splitIndex) - - threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt, - inputId, myId, splitIndex)) - - // splitIndex is in transit. Will be unset in the ShuffleClient - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex) - } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before creating new threads - Thread.sleep(Shuffle.MinKnockInterval) - } - - threadPool.shutdown() - - // Start consumer - // TODO: Consumption is delayed until everything has been received. - // Otherwise it interferes with network performance - var shuffleConsumer = new ShuffleConsumer(mergeCombiners) - shuffleConsumer.setDaemon(true) - shuffleConsumer.start() - logInfo("ShuffleConsumer started...") - - // Don't return until consumption is finished - // while (receivedData.size > 0) { - // Thread.sleep(Shuffle.MinKnockInterval) - // } - - // Wait till shuffleConsumer is done - shuffleConsumer.join - - combiners - }) - } - - def selectRandomSplit: Int = { - var requiredSplits = new ArrayBuffer[Int] - - synchronized { - for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { - requiredSplits += i - } - } - } - - if (requiredSplits.size > 0) { - requiredSplits(HttpBlockedLocalFileShuffle.ranGen.nextInt( - requiredSplits.size)) - } else { - -1 - } - } - - class ShuffleConsumer(mergeCombiners: (C, C) => C) - extends Thread with Logging { - override def run: Unit = { - // Run until all splits are here - while (receivedData.size > 0) { - var splitIndex = -1 - var recvByteArray: Array[Byte] = null - - try { - var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] - splitIndex = tempPair._1 - recvByteArray = tempPair._2 - } catch { - case e: Exception => { - logInfo("Exception during taking data from receivedData") - } - } - - val inputStream = - new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) - - try{ - while (true) { - val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } - } - } catch { - case e: EOFException => { } - } - inputStream.close() - } - } - } - - class ShuffleClient(serverUri: String, shuffleId: Int, - inputId: Int, myId: Int, splitIndex: Int) - extends Thread with Logging { - private var receptionSucceeded = false - - override def run: Unit = { - try { - // First get the INDEX file if totalBlocksInSplit(splitIndex) is unknown - if (totalBlocksInSplit(splitIndex) == -1) { - val url = "%s/shuffle/%d/%d/INDEX-%d".format(serverUri, shuffleId, - inputId, myId) - val inputStream = new ObjectInputStream(new URL(url).openStream()) - - try { - while (true) { - blocksInSplit(splitIndex) += - inputStream.readObject().asInstanceOf[Long] - } - } catch { - case e: EOFException => {} - } - - totalBlocksInSplit(splitIndex) = blocksInSplit(splitIndex).size - inputStream.close() - } - - // Open connection - val urlString = - "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, myId) - val url = new URL(urlString) - val httpConnection = - url.openConnection().asInstanceOf[HttpURLConnection] - - // Set the range to download - val blockStartsAt = hasBlocksInSplit(splitIndex) match { - case 0 => 0 - case _ => blocksInSplit(splitIndex)(hasBlocksInSplit(splitIndex) - 1) + 1 - } - val blockEndsAt = blocksInSplit(splitIndex)(hasBlocksInSplit(splitIndex)) - httpConnection.setRequestProperty("Range", - "bytes=" + blockStartsAt + "-" + blockEndsAt) - - // Connect to the server - httpConnection.connect() - - val urStringWithRange = - urlString + "[%d:%d]".format(blockStartsAt, blockEndsAt) - val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: " + urStringWithRange) - - // Receive data in an Array[Byte] - val requestedFileLen: Int = (blockEndsAt - blockStartsAt).toInt + 1 - var recvByteArray = new Array[Byte](requestedFileLen) - var alreadyRead = 0 - var bytesRead = 0 - - val isSource = httpConnection.getInputStream() - while (alreadyRead != requestedFileLen) { - bytesRead = isSource.read(recvByteArray, alreadyRead, - requestedFileLen - alreadyRead) - if (bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - - // Disconnect - httpConnection.disconnect() - - // Make it available to the consumer - try { - receivedData.put((splitIndex, recvByteArray)) - } catch { - case e: Exception => { - logInfo("Exception during putting data into receivedData") - } - } - - // TODO: Updating stats before consumption is completed - hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 - - // Split has been received only if all the blocks have been received - if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set(splitIndex) - hasSplits += 1 - } - } - - // We have received splitIndex - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - - receptionSucceeded = true - - logInfo("END READ: " + urStringWithRange) - val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading " + urStringWithRange + " took " + readTime + " millis.") - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleClient had a " + e) - } - } finally { - // If reception failed, unset for future retry - if (!receptionSucceeded) { - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - } - } - } - } -} - -object HttpBlockedLocalFileShuffle extends Logging { - private var initialized = false - private var nextShuffleId = new AtomicLong(0) - - // Variables initialized by initializeIfNeeded() - private var shuffleDir: File = null - private var server: HttpServer = null - private var serverUri: String = null - - // Random number generator - var ranGen = new Random - - private def initializeIfNeeded() = synchronized { - if (!initialized) { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - val extServerPort = System.getProperty( - "spark.localFileShuffle.external.server.port", "-1").toInt - if (extServerPort != -1) { - // We're using an external HTTP server; set URI relative to its root - var extServerPath = System.getProperty( - "spark.localFileShuffle.external.server.path", "") - if (extServerPath != "" && !extServerPath.endsWith("/")) { - extServerPath += "/" - } - serverUri = "http://%s:%d/%s/spark-local-%s".format( - Utils.localIpAddress, extServerPort, extServerPath, localDirUuid) - } else { - // Create our own server - server = new HttpServer(localDir) - server.start() - serverUri = server.uri - } - initialized = true - logInfo("Local URI: " + serverUri) - } - } - - def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { - initializeIfNeeded() - val dir = new File(shuffleDir, shuffleId + "/" + inputId) - dir.mkdirs() - val file = new File(dir, "" + outputId) - return file - } - - def getBlockIndexOutputFile(shuffleId: Long, inputId: Int, - outputId: Int): File = { - initializeIfNeeded() - val dir = new File(shuffleDir, shuffleId + "/" + inputId) - dir.mkdirs() - val file = new File(dir, "INDEX-" + outputId) - return file - } - - def getServerUri(): String = { - initializeIfNeeded() - serverUri - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } - - // Returns a standard ThreadFactory except all threads are daemons - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread(r) - t.setDaemon(true) - return t - } - } - } - - // Wrapper over newFixedThreadPool - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = - Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } -} diff --git a/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala deleted file mode 100644 index cfd84fdb83..0000000000 --- a/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala +++ /dev/null @@ -1,465 +0,0 @@ -package spark - -import java.io._ -import java.net._ -import java.util.{BitSet, Random, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -/** - * An implementation of shuffle using local files served through HTTP where - * receivers create simultaneous connections to multiple servers by setting the - * 'spark.shuffle.maxRxConnections' config option. - * - * By controlling the 'spark.shuffle.blockSize' config option one can also - * control the largest block size to divide each map output into. Essentially, - * instead of creating one large output file for each reducer, maps create - * multiple smaller files to enable finer level of engagement. - * - * TODO: Add support for compression when spark.compress is set to true. - */ -@serializable -class ManualBlockedLocalFileShuffle[K, V, C] -extends Shuffle[K, V, C] with Logging { - @transient var totalSplits = 0 - @transient var hasSplits = 0 - - @transient var totalBlocksInSplit: Array[Int] = null - @transient var hasBlocksInSplit: Array[Int] = null - - @transient var hasSplitsBitVector: BitSet = null - @transient var splitsInRequestBitVector: BitSet = null - - @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null - @transient var combiners: HashMap[K,C] = null - - override def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] = - { - val sc = input.sparkContext - val shuffleId = ManualBlockedLocalFileShuffle.newShuffleId() - logInfo("Shuffle ID: " + shuffleId) - - val splitRdd = new NumberedSplitRDD(input) - val numInputSplits = splitRdd.splits.size - - // Run a parallel map and collect to write the intermediate data files, - // returning a list of inputSplitId -> serverUri pairs - val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { - val myIndex = pair._1 - val myIterator = pair._2 - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) - for ((k, v) <- myIterator) { - var bucketId = k.hashCode % numOutputSplits - if (bucketId < 0) { // Fix bucket ID if hash code was negative - bucketId += numOutputSplits - } - val bucket = buckets(bucketId) - bucket(k) = bucket.get(k) match { - case Some(c) => mergeValue(c, v) - case None => createCombiner(v) - } - } - - for (i <- 0 until numOutputSplits) { - var blockNum = 0 - var isDirty = false - var file: File = null - var out: ObjectOutputStream = null - - var writeStartTime: Long = 0 - - buckets(i).foreach(pair => { - // Open a new file if necessary - if (!isDirty) { - file = ManualBlockedLocalFileShuffle.getOutputFile(shuffleId, - myIndex, i, blockNum) - writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + file) - - out = new ObjectOutputStream(new FileOutputStream(file)) - } - - out.writeObject(pair) - out.flush() - isDirty = true - - // Close the old file if has crossed the blockSize limit - if (file.length > Shuffle.BlockSize) { - out.close() - logInfo("END WRITE: " + file) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - isDirty = false - } - }) - - if (isDirty) { - out.close() - logInfo("END WRITE: " + file) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - } - - // Write the BLOCKNUM file - file = ManualBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId, - myIndex, i) - out = new ObjectOutputStream(new FileOutputStream(file)) - out.writeObject(blockNum) - out.close() - } - - (myIndex, ManualBlockedLocalFileShuffle.serverUri) - }).collect() - - // TODO: Could broadcast outputLocs - - // Return an RDD that does each of the merges for a given partition - val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) - return indexes.flatMap((myId: Int) => { - totalSplits = outputLocs.size - hasSplits = 0 - - totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) - hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) - - hasSplitsBitVector = new BitSet(totalSplits) - splitsInRequestBitVector = new BitSet(totalSplits) - - receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] - combiners = new HashMap[K, C] - - var threadPool = Shuffle.newDaemonFixedThreadPool( - Shuffle.MaxRxConnections) - - while (hasSplits < totalSplits) { - var numThreadsToCreate = - Math.min(totalSplits, Shuffle.MaxRxConnections) - - threadPool.getActiveCount - - while (hasSplits < totalSplits && numThreadsToCreate > 0) { - // Select a random split to pull - val splitIndex = selectRandomSplit - - if (splitIndex != -1) { - val (inputId, serverUri) = outputLocs(splitIndex) - - threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt, - inputId, myId, splitIndex, mergeCombiners)) - - // splitIndex is in transit. Will be unset in the ShuffleClient - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex) - } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before creating new threads - Thread.sleep(Shuffle.MinKnockInterval) - } - - threadPool.shutdown() - - // Start consumer - // TODO: Consumption is delayed until everything has been received. - // Otherwise it interferes with network performance - var shuffleConsumer = new ShuffleConsumer(mergeCombiners) - shuffleConsumer.setDaemon(true) - shuffleConsumer.start() - logInfo("ShuffleConsumer started...") - - // Don't return until consumption is finished - // while (receivedData.size > 0) { - // Thread.sleep(Shuffle.MinKnockInterval) - // } - - // Wait till shuffleConsumer is done - shuffleConsumer.join - - combiners - }) - } - - def selectRandomSplit: Int = { - var requiredSplits = new ArrayBuffer[Int] - - synchronized { - for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { - requiredSplits += i - } - } - } - - if (requiredSplits.size > 0) { - requiredSplits(ManualBlockedLocalFileShuffle.ranGen.nextInt( - requiredSplits.size)) - } else { - -1 - } - } - - class ShuffleConsumer(mergeCombiners: (C, C) => C) - extends Thread with Logging { - override def run: Unit = { - // Run until all splits are here - while (receivedData.size > 0) { - var splitIndex = -1 - var recvByteArray: Array[Byte] = null - - try { - var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] - splitIndex = tempPair._1 - recvByteArray = tempPair._2 - } catch { - case e: Exception => { - logInfo("Exception during taking data from receivedData") - } - } - - val inputStream = - new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) - - try{ - while (true) { - val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } - } - } catch { - case e: EOFException => { } - } - inputStream.close() - } - } - } - - class ShuffleClient(serverUri: String, shuffleId: Int, - inputId: Int, myId: Int, splitIndex: Int, - mergeCombiners: (C, C) => C) - extends Thread with Logging { - private var receptionSucceeded = false - - override def run: Unit = { - try { - // Everything will break if BLOCKNUM is not correctly received - // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown - if (totalBlocksInSplit(splitIndex) == -1) { - val url = "%s/shuffle/%d/%d/BLOCKNUM-%d".format(serverUri, shuffleId, - inputId, myId) - val inputStream = new ObjectInputStream(new URL(url).openStream()) - totalBlocksInSplit(splitIndex) = - inputStream.readObject().asInstanceOf[Int] - inputStream.close() - } - - // Open connection - val urlString = - "%s/shuffle/%d/%d/%d-%d".format(serverUri, shuffleId, inputId, - myId, hasBlocksInSplit(splitIndex)) - val url = new URL(urlString) - val httpConnection = - url.openConnection().asInstanceOf[HttpURLConnection] - - // Connect to the server - httpConnection.connect() - - // Receive file length - var requestedFileLen = httpConnection.getContentLength - - val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: " + url) - - // Receive data in an Array[Byte] - var recvByteArray = new Array[Byte](requestedFileLen) - var alreadyRead = 0 - var bytesRead = 0 - - val isSource = httpConnection.getInputStream() - while (alreadyRead != requestedFileLen) { - bytesRead = isSource.read(recvByteArray, alreadyRead, - requestedFileLen - alreadyRead) - if (bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - - // Disconnect - httpConnection.disconnect() - - // Make it available to the consumer - try { - receivedData.put((splitIndex, recvByteArray)) - } catch { - case e: Exception => { - logInfo("Exception during putting data into receivedData") - } - } - - // TODO: Updating stats before consumption is completed - hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 - - // Split has been received only if all the blocks have been received - if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set(splitIndex) - hasSplits += 1 - } - } - - // We have received splitIndex - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - - receptionSucceeded = true - - logInfo("END READ: " + url) - val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading " + url + " took " + readTime + " millis.") - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleClient had a " + e) - } - } finally { - // If reception failed, unset for future retry - if (!receptionSucceeded) { - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - } - } - } - } -} - -object ManualBlockedLocalFileShuffle extends Logging { - private var initialized = false - private var nextShuffleId = new AtomicLong(0) - - // Variables initialized by initializeIfNeeded() - private var shuffleDir: File = null - private var server: HttpServer = null - private var serverUri: String = null - - // Random number generator - var ranGen = new Random - - private def initializeIfNeeded() = synchronized { - if (!initialized) { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - val extServerPort = System.getProperty( - "spark.localFileShuffle.external.server.port", "-1").toInt - if (extServerPort != -1) { - // We're using an external HTTP server; set URI relative to its root - var extServerPath = System.getProperty( - "spark.localFileShuffle.external.server.path", "") - if (extServerPath != "" && !extServerPath.endsWith("/")) { - extServerPath += "/" - } - serverUri = "http://%s:%d/%s/spark-local-%s".format( - Utils.localIpAddress, extServerPort, extServerPath, localDirUuid) - } else { - // Create our own server - server = new HttpServer(localDir) - server.start() - serverUri = server.uri - } - initialized = true - logInfo("Local URI: " + serverUri) - } - } - - def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int, - blockId: Int): File = { - initializeIfNeeded() - val dir = new File(shuffleDir, shuffleId + "/" + inputId) - dir.mkdirs() - val file = new File(dir, "%d-%d".format(outputId, blockId)) - return file - } - - def getBlockNumOutputFile(shuffleId: Long, inputId: Int, - outputId: Int): File = { - initializeIfNeeded() - val dir = new File(shuffleDir, shuffleId + "/" + inputId) - dir.mkdirs() - val file = new File(dir, "BLOCKNUM-" + outputId) - return file - } - - def getServerUri(): String = { - initializeIfNeeded() - serverUri - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } - - // Returns a standard ThreadFactory except all threads are daemons - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread(r) - t.setDaemon(true) - return t - } - } - } - - // Wrapper over newFixedThreadPool - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = - Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } -} diff --git a/core/src/main/scala/spark/ShuffleTrackerStrategy.scala b/core/src/main/scala/spark/ShuffleTrackerStrategy.scala deleted file mode 100644 index fc2f4aa5f7..0000000000 --- a/core/src/main/scala/spark/ShuffleTrackerStrategy.scala +++ /dev/null @@ -1,470 +0,0 @@ -package spark - -import java.util.{BitSet, Random} - -import scala.collection.mutable.ArrayBuffer -import scala.util.Sorting._ - -/** - * A trait for implementing tracker strategies for the shuffle system. - */ -trait ShuffleTrackerStrategy { - // Initialize - def initialize(outputLocs_ : Array[SplitInfo]): Unit - - // Select a set of splits and send back - def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] - - // Update internal stats if things could be sent back successfully - def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit - - // A reducer is done. Update internal stats - def deleteReducerFrom(reducerSplitInfo: SplitInfo, - receptionStat: ReceptionStats): Unit -} - -/** - * Helper class to send back reception stats from the reducer - */ -case class ReceptionStats(val bytesReceived: Int, val timeSpent: Int, - serverSplitIndex: Int) { } - -/** - * A simple ShuffleTrackerStrategy that tries to balance the total number of - * connections created for each mapper. - */ -class BalanceConnectionsShuffleTrackerStrategy -extends ShuffleTrackerStrategy with Logging { - private var numSources = -1 - private var outputLocs: Array[SplitInfo] = null - private var curConnectionsPerLoc: Array[Int] = null - private var totalConnectionsPerLoc: Array[Int] = null - - // The order of elements in the outputLocs (splitIndex) is used to pass - // information back and forth between the tracker, mappers, and reducers - def initialize(outputLocs_ : Array[SplitInfo]): Unit = { - outputLocs = outputLocs_ - numSources = outputLocs.size - - // Now initialize other data structures - curConnectionsPerLoc = Array.tabulate(numSources)(_ => 0) - totalConnectionsPerLoc = Array.tabulate(numSources)(_ => 0) - } - - def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized { - var minConnections = Int.MaxValue - var minIndex = -1 - - var splitIndices = ArrayBuffer[Int]() - - for (i <- 0 until numSources) { - // TODO: Use of MaxRxConnections instead of MaxTxConnections is - // intentional here. MaxTxConnections is per machine whereas - // MaxRxConnections is per mapper/reducer. Will have to find a better way. - if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections && - totalConnectionsPerLoc(i) < minConnections && - !reducerSplitInfo.hasSplitsBitVector.get(i)) { - minConnections = totalConnectionsPerLoc(i) - minIndex = i - } - } - - if (minIndex != -1) { - splitIndices += minIndex - } - - return splitIndices - } - - def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized { - splitIndices.foreach { splitIndex => - curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1 - totalConnectionsPerLoc(splitIndex) = - totalConnectionsPerLoc(splitIndex) + 1 - } - } - - def deleteReducerFrom(reducerSplitInfo: SplitInfo, - receptionStat: ReceptionStats): Unit = synchronized { - // Decrease number of active connections - curConnectionsPerLoc(receptionStat.serverSplitIndex) = - curConnectionsPerLoc(receptionStat.serverSplitIndex) - 1 - - // TODO: This assertion can legally fail when ShuffleClient times out while - // waiting for tracker response and decides to go to a random server - // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) - - // Just in case - if (curConnectionsPerLoc(receptionStat.serverSplitIndex) < 0) { - curConnectionsPerLoc(receptionStat.serverSplitIndex) = 0 - } - } -} - -/** - * A simple ShuffleTrackerStrategy that randomly selects mapper for each reducer - */ -class SelectRandomShuffleTrackerStrategy -extends ShuffleTrackerStrategy with Logging { - private var numMappers = -1 - private var outputLocs: Array[SplitInfo] = null - - private var ranGen = new Random - - // The order of elements in the outputLocs (splitIndex) is used to pass - // information back and forth between the tracker, mappers, and reducers - def initialize(outputLocs_ : Array[SplitInfo]): Unit = { - outputLocs = outputLocs_ - numMappers = outputLocs.size - } - - def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized { - var splitIndex = -1 - - do { - splitIndex = ranGen.nextInt(numMappers) - } while (reducerSplitInfo.hasSplitsBitVector.get(splitIndex)) - - return ArrayBuffer(splitIndex) - } - - def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized { - } - - def deleteReducerFrom(reducerSplitInfo: SplitInfo, - receptionStat: ReceptionStats): Unit = synchronized { - // TODO: This assertion can legally fail when ShuffleClient times out while - // waiting for tracker response and decides to go to a random server - // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) - } -} - -/** - * Shuffle tracker strategy that tries to balance the percentage of blocks - * remaining for each reducer - */ -class BalanceRemainingShuffleTrackerStrategy -extends ShuffleTrackerStrategy with Logging { - // Number of mappers - private var numMappers = -1 - // Number of reducers - private var numReducers = -1 - private var outputLocs: Array[SplitInfo] = null - - // Data structures from reducers' perspectives - private var totalBlocksPerInputSplit: Array[Array[Int]] = null - private var hasBlocksPerInputSplit: Array[Array[Int]] = null - - // Stored in bytes per millisecond - private var speedPerInputSplit: Array[Array[Double]] = null - - private var curConnectionsPerLoc: Array[Int] = null - private var totalConnectionsPerLoc: Array[Int] = null - - // The order of elements in the outputLocs (splitIndex) is used to pass - // information back and forth between the tracker, mappers, and reducers - def initialize(outputLocs_ : Array[SplitInfo]): Unit = { - outputLocs = outputLocs_ - - numMappers = outputLocs.size - - // All the outputLocs have totalBlocksPerOutputSplit of same size - numReducers = outputLocs(0).totalBlocksPerOutputSplit.size - - // Now initialize the data structures - totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) => - outputLocs(j).totalBlocksPerOutputSplit(i)) - hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0) - - // Initialize to -1 - speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1.0) - - curConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0) - totalConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0) - } - - def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized { - var splitIndex = -1 - - // Estimate time remaining to finish receiving for all reducer/mapper pairs - // If speed is unknown or zero then make it 1 to give a large estimate - var individualEstimates = Array.tabulate(numReducers, numMappers)((_,_) => 0.0) - for (i <- 0 until numReducers; j <- 0 until numMappers) { - var blocksRemaining = totalBlocksPerInputSplit(i)(j) - - hasBlocksPerInputSplit(i)(j) - assert(blocksRemaining >= 0) - - individualEstimates(i)(j) = 1.0 * blocksRemaining * Shuffle.BlockSize / - { if (speedPerInputSplit(i)(j) <= 0.0) 1.0 else speedPerInputSplit(i)(j) } - } - - // Check if all speedPerInputSplit entries have non-zero values - var estimationComplete = true - for (i <- 0 until numReducers; j <- 0 until numMappers) { - if (speedPerInputSplit(i)(j) < 0.0) { - estimationComplete = false - } - } - - // Mark mappers where this reducer is too fast - var throttleFromMapper = Array.tabulate(numMappers)(_ => false) - - for (i <- 0 until numMappers) { - var estimatesFromAMapper = - Array.tabulate(numReducers)(j => individualEstimates(j)(i)) - - val estimateOfThisReducer = estimatesFromAMapper(reducerSplitInfo.splitId) - - // Only care if this reducer yet has something to receive from this mapper - if (estimateOfThisReducer > 0) { - // Sort the estimated times - quickSort(estimatesFromAMapper) - - // Find a Shuffle.ThrottleFraction amount of gap - var gapIndex = -1 - for (i <- 0 until numReducers - 1) { - if (gapIndex == -1 && estimatesFromAMapper(i) > 0 && - (Shuffle.ThrottleFraction * estimatesFromAMapper(i) < - estimatesFromAMapper(i + 1))) { - gapIndex = i - } - - assert (estimatesFromAMapper(i) <= estimatesFromAMapper(i + 1)) - } - - // Keep track of how many have completed - var numComplete = estimatesFromAMapper.findIndexOf(i => (i > 0)) - if (numComplete == -1) { - numComplete = numReducers - } - - // TODO: Pick a configurable parameter - if (gapIndex != -1 && (1.0 * (gapIndex - numComplete + 1) < 0.1 * Shuffle.ThrottleFraction * (numReducers - numComplete)) && - estimateOfThisReducer <= estimatesFromAMapper(gapIndex)) { - throttleFromMapper(i) = true - logInfo("Throttling R-%d at M-%d with %f and cut-off %f at %d".format(reducerSplitInfo.splitId, i, estimateOfThisReducer, estimatesFromAMapper(gapIndex + 1), gapIndex)) -// for (i <- 0 until numReducers) { -// print(estimatesFromAMapper(i) + " ") -// } -// println("") - } - } else { - throttleFromMapper(i) = true - } - } - - var minConnections = Int.MaxValue - for (i <- 0 until numMappers) { - // TODO: Use of MaxRxConnections instead of MaxTxConnections is - // intentional here. MaxTxConnections is per machine whereas - // MaxRxConnections is per mapper/reducer. Will have to find a better way. - if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections && - totalConnectionsPerLoc(i) < minConnections && - !reducerSplitInfo.hasSplitsBitVector.get(i) && - !throttleFromMapper(i)) { - minConnections = totalConnectionsPerLoc(i) - splitIndex = i - } - } - - return ArrayBuffer(splitIndex) - } - - def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized { - splitIndices.foreach { splitIndex => - curConnectionsPerLoc(splitIndex) += 1 - totalConnectionsPerLoc(splitIndex) += 1 - } - } - - def deleteReducerFrom(reducerSplitInfo: SplitInfo, - receptionStat: ReceptionStats): Unit = synchronized { - // Update hasBlocksPerInputSplit for reducerSplitInfo - hasBlocksPerInputSplit(reducerSplitInfo.splitId) = - reducerSplitInfo.hasBlocksPerInputSplit - - // Store the last known speed. Add 1 to avoid divide-by-zero. Ignore 0 bytes - // TODO: We are forgetting the old speed. Can use averaging at some point. - if (receptionStat.bytesReceived > 0) { - speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) = - 1.0 * receptionStat.bytesReceived / (receptionStat.timeSpent + 1.0) - } - - // logInfo("%d received %d bytes in %d millis".format(reducerSplitInfo.splitId, receptionStat.bytesReceived, receptionStat.timeSpent)) - - // Update current connections to the mapper - curConnectionsPerLoc(receptionStat.serverSplitIndex) -= 1 - - // TODO: This assertion can legally fail when ShuffleClient times out while - // waiting for tracker response and decides to go to a random server - // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) - - // Just in case - if (curConnectionsPerLoc(receptionStat.serverSplitIndex) < 0) { - curConnectionsPerLoc(receptionStat.serverSplitIndex) = 0 - } - } -} - -/** - * Shuffle tracker strategy that allows reducers to create receiving threads - * depending on their estimated time remaining - */ -class LimitConnectionsShuffleTrackerStrategy -extends ShuffleTrackerStrategy with Logging { - // Number of mappers - private var numMappers = -1 - // Number of reducers - private var numReducers = -1 - private var outputLocs: Array[SplitInfo] = null - - private var ranGen = new Random - - // Data structures from reducers' perspectives - private var totalBlocksPerInputSplit: Array[Array[Int]] = null - private var hasBlocksPerInputSplit: Array[Array[Int]] = null - - // Stored in bytes per millisecond - private var speedPerInputSplit: Array[Array[Double]] = null - - private var curConnectionsPerReducer: Array[Int] = null - private var maxConnectionsPerReducer: Array[Int] = null - - // The order of elements in the outputLocs (splitIndex) is used to pass - // information back and forth between the tracker, mappers, and reducers - def initialize(outputLocs_ : Array[SplitInfo]): Unit = { - outputLocs = outputLocs_ - - numMappers = outputLocs.size - - // All the outputLocs have totalBlocksPerOutputSplit of same size - numReducers = outputLocs(0).totalBlocksPerOutputSplit.size - - // Now initialize the data structures - totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) => - outputLocs(j).totalBlocksPerOutputSplit(i)) - hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0) - - // Initialize to -1 - speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1.0) - - curConnectionsPerReducer = Array.tabulate(numReducers)(_ => 0) - maxConnectionsPerReducer = Array.tabulate(numReducers)(_ => Shuffle.MaxRxConnections) - } - - def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized { - var splitIndices = ArrayBuffer[Int]() - - // Estimate time remaining to finish receiving for all reducer/mapper pairs - // If speed is unknown or zero then make it 1 to give a large estimate - var individualEstimates = Array.tabulate(numReducers, numMappers)((_,_) => 0.0) - for (i <- 0 until numReducers; j <- 0 until numMappers) { - var blocksRemaining = totalBlocksPerInputSplit(i)(j) - - hasBlocksPerInputSplit(i)(j) - assert(blocksRemaining >= 0) - - individualEstimates(i)(j) = 1.0 * blocksRemaining * Shuffle.BlockSize / - { if (speedPerInputSplit(i)(j) <= 0.0) 1.0 else speedPerInputSplit(i)(j) } - } - - // Check if all speedPerInputSplit entries have non-zero values - var estimationComplete = true - for (i <- 0 until numReducers; j <- 0 until numMappers) { - if (speedPerInputSplit(i)(j) < 0.0) { - estimationComplete = false - } - } - - // Estimate time remaining to finish receiving for each reducer - var completionEstimates = Array.tabulate(numReducers)( - individualEstimates(_).foldLeft(Double.MinValue)(Math.max(_,_))) - - var numFinished = 0 - for (i <- 0 until numReducers) { - if (completionEstimates(i).toInt == 0) { - numFinished += 1 - } - } - - // If a certain number of reducers have finished already, then don't bother - if (numFinished >= ((1.0 - 0.1 * Shuffle.ThrottleFraction) * numReducers)) { - for (i <- 0 until numReducers) { - maxConnectionsPerReducer(i) = Shuffle.MaxRxConnections - } - // Otherwise, if estimation is complete give reducers proportional threads - } else if (estimationComplete) { - val fastestEstimate = - completionEstimates.foldLeft(Double.MaxValue)(Math.min(_,_)) - val slowestEstimate = - completionEstimates.foldLeft(Double.MinValue)(Math.max(_,_)) - - // Set maxConnectionsPerReducer for all reducers proportional to their - // estimated time remaining with slowestEstimate reducer having the max - for (i <- 0 until numReducers) { - maxConnectionsPerReducer(i) = - ((completionEstimates(i) / slowestEstimate) * Shuffle.MaxRxConnections).toInt - } - } - - // Send back a splitIndex if this reducer is within its limit - if (curConnectionsPerReducer(reducerSplitInfo.splitId) < - maxConnectionsPerReducer(reducerSplitInfo.splitId)) { - - var i = maxConnectionsPerReducer(reducerSplitInfo.splitId) - - curConnectionsPerReducer(reducerSplitInfo.splitId) - - var temp = reducerSplitInfo.hasSplitsBitVector.clone.asInstanceOf[BitSet] - temp.flip(0, numMappers) - - i = Math.min(i, temp.cardinality) - - while (i > 0) { - var splitIndex = -1 - - do { - splitIndex = ranGen.nextInt(numMappers) - } while (reducerSplitInfo.hasSplitsBitVector.get(splitIndex)) - - reducerSplitInfo.hasSplitsBitVector.set(splitIndex) - splitIndices += splitIndex - i -= 1 - } - } - - return splitIndices - } - - def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized { - splitIndices.foreach { splitIndex => - curConnectionsPerReducer(reducerSplitInfo.splitId) += 1 - } - } - - def deleteReducerFrom(reducerSplitInfo: SplitInfo, - receptionStat: ReceptionStats): Unit = synchronized { - // Update hasBlocksPerInputSplit for reducerSplitInfo - hasBlocksPerInputSplit(reducerSplitInfo.splitId) = - reducerSplitInfo.hasBlocksPerInputSplit - - // Store the last known speed. Add 1 to avoid divide-by-zero. Ignore 0 bytes - // TODO: We are forgetting the old speed. Can use averaging at some point. - if (receptionStat.bytesReceived > 0) { - speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) = - 1.0 * receptionStat.bytesReceived / (receptionStat.timeSpent + 1.0) - } - - // logInfo("%d received %d bytes in %d millis".format(reducerSplitInfo.splitId, receptionStat.bytesReceived, receptionStat.timeSpent)) - - // Update current threads by this reducer - curConnectionsPerReducer(reducerSplitInfo.splitId) -= 1 - - // TODO: This assertion can legally fail when ShuffleClient times out while - // waiting for tracker response and decides to go to a random server - // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) - - // Just in case - if (curConnectionsPerReducer(reducerSplitInfo.splitId) < 0) { - curConnectionsPerReducer(reducerSplitInfo.splitId) = 0 - } - } -} diff --git a/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala b/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala deleted file mode 100644 index 0d21df9338..0000000000 --- a/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala +++ /dev/null @@ -1,926 +0,0 @@ -package spark - -import java.io._ -import java.net._ -import java.util.{BitSet, Random, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -/** - * An implementation of shuffle using memory served through custom server - * where receivers create simultaneous connections to multiple servers by - * setting the 'spark.shuffle.maxRxConnections' config option. - * - * By controlling the 'spark.shuffle.blockSize' config option one can also - * control the largest block size to divide each map output into. Essentially, - * instead of creating one large output file for each reducer, maps create - * multiple smaller files to enable finer level of engagement. - * - * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, - * maxTxConnections >= maxRxConnections * numReducersPerMachine - * - * 'spark.shuffle.TrackerStrategy' decides which strategy to use in the tracker - * - * TODO: Add support for compression when spark.compress is set to true. - */ -@serializable -class TrackedCustomBlockedInMemoryShuffle[K, V, C] -extends Shuffle[K, V, C] with Logging { - @transient var totalSplits = 0 - @transient var hasSplits = 0 - - @transient var totalBlocksInSplit: Array[Int] = null - @transient var hasBlocksInSplit: Array[Int] = null - - @transient var hasSplitsBitVector: BitSet = null - @transient var splitsInRequestBitVector: BitSet = null - - @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null - @transient var combiners: HashMap[K,C] = null - - override def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] = - { - val sc = input.sparkContext - val shuffleId = TrackedCustomBlockedInMemoryShuffle.newShuffleId() - logInfo("Shuffle ID: " + shuffleId) - - val splitRdd = new NumberedSplitRDD(input) - val numInputSplits = splitRdd.splits.size - - // Run a parallel map and collect to write the intermediate data files, - // returning a list of inputSplitId -> serverUri pairs - val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { - val myIndex = pair._1 - val myIterator = pair._2 - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) - for ((k, v) <- myIterator) { - var bucketId = k.hashCode % numOutputSplits - if (bucketId < 0) { // Fix bucket ID if hash code was negative - bucketId += numOutputSplits - } - val bucket = buckets(bucketId) - bucket(k) = bucket.get(k) match { - case Some(c) => mergeValue(c, v) - case None => createCombiner(v) - } - } - - // Keep track of number of blocks for each output split - var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0) - - for (i <- 0 until numOutputSplits) { - var blockNum = 0 - var isDirty = false - - var splitName = "" - var baos: ByteArrayOutputStream = null - var oos: ObjectOutputStream = null - - var writeStartTime: Long = 0 - - buckets(i).foreach(pair => { - // Open a new stream if necessary - if (!isDirty) { - splitName = TrackedCustomBlockedInMemoryShuffle.getSplitName(shuffleId, - myIndex, i, blockNum) - - baos = new ByteArrayOutputStream - oos = new ObjectOutputStream(baos) - - writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + splitName) - } - - oos.writeObject(pair) - isDirty = true - - // Close the old stream if has crossed the blockSize limit - if (baos.size > Shuffle.BlockSize) { - TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = - baos.toByteArray - - logInfo("END WRITE: " + splitName) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - isDirty = false - oos.close() - } - }) - - if (isDirty) { - TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray - - logInfo("END WRITE: " + splitName) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - oos.close() - } - - // Store BLOCKNUM info - splitName = TrackedCustomBlockedInMemoryShuffle.getBlockNumOutputName( - shuffleId, myIndex, i) - baos = new ByteArrayOutputStream - oos = new ObjectOutputStream(baos) - oos.writeObject(blockNum) - TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray - - // Close streams - oos.close() - - // Store number of blocks for this outputSplit - numBlocksPerOutputSplit(i) = blockNum - } - - var retVal = SplitInfo(TrackedCustomBlockedInMemoryShuffle.serverAddress, - TrackedCustomBlockedInMemoryShuffle.serverPort, myIndex) - retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit - - (retVal) - }).collect() - - // Start tracker - var shuffleTracker = new ShuffleTracker(outputLocs) - shuffleTracker.setDaemon(true) - shuffleTracker.start() - logInfo("ShuffleTracker started...") - - // Return an RDD that does each of the merges for a given partition - val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) - return indexes.flatMap((myId: Int) => { - totalSplits = outputLocs.size - hasSplits = 0 - - totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) - hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) - - hasSplitsBitVector = new BitSet(totalSplits) - splitsInRequestBitVector = new BitSet(totalSplits) - - receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] - combiners = new HashMap[K, C] - - var threadPool = Shuffle.newDaemonFixedThreadPool( - Shuffle.MaxRxConnections) - - while (hasSplits < totalSplits) { - // Local status of hasSplitsBitVector and splitsInRequestBitVector - val localSplitInfo = getLocalSplitInfo(myId) - - // DO NOT talk to the tracker if all the required splits are already busy - val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality - - var numThreadsToCreate = - Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) - - threadPool.getActiveCount - - while (hasSplits < totalSplits && numThreadsToCreate > 0) { - // Receive which split to pull from the tracker - logInfo("Talking to tracker...") - val startTime = System.currentTimeMillis - val splitIndices = getTrackerSelectedSplit(myId) - logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime)) - - if (splitIndices.size > 0) { - splitIndices.foreach { splitIndex => - val selectedSplitInfo = outputLocs(splitIndex) - val requestSplit = - "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId) - - threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo, - requestSplit, myId)) - - // splitIndex is in transit. Will be unset in the ShuffleClient - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex) - } - } - } else { - // Tracker replied back with a NO. Sleep for a while. - Thread.sleep(Shuffle.MinKnockInterval) - numThreadsToCreate = 0 - } - - numThreadsToCreate = numThreadsToCreate - splitIndices.size - } - - // Sleep for a while before creating new threads - Thread.sleep(Shuffle.MinKnockInterval) - } - - threadPool.shutdown() - - // Start consumer - // TODO: Consumption is delayed until everything has been received. - // Otherwise it interferes with network performance - var shuffleConsumer = new ShuffleConsumer(mergeCombiners) - shuffleConsumer.setDaemon(true) - shuffleConsumer.start() - logInfo("ShuffleConsumer started...") - - // Don't return until consumption is finished - // while (receivedData.size > 0) { - // Thread.sleep(Shuffle.MinKnockInterval) - // } - - // Wait till shuffleConsumer is done - shuffleConsumer.join - - combiners - }) - } - - private def getLocalSplitInfo(myId: Int): SplitInfo = { - var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress, - SplitInfo.UnusedParam, myId) - - // Store hasSplits - localSplitInfo.hasSplits = hasSplits - - // Store hasSplitsBitVector - hasSplitsBitVector.synchronized { - localSplitInfo.hasSplitsBitVector = - hasSplitsBitVector.clone.asInstanceOf[BitSet] - } - - // Store hasBlocksInSplit to hasBlocksPerInputSplit - hasBlocksInSplit.synchronized { - localSplitInfo.hasBlocksPerInputSplit = - hasBlocksInSplit.clone.asInstanceOf[Array[Int]] - } - - // Include the splitsInRequest as well - splitsInRequestBitVector.synchronized { - localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector) - } - - return localSplitInfo - } - - def selectRandomSplit: Int = { - var requiredSplits = new ArrayBuffer[Int] - - synchronized { - for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { - requiredSplits += i - } - } - } - - if (requiredSplits.size > 0) { - requiredSplits(TrackedCustomBlockedInMemoryShuffle.ranGen.nextInt( - requiredSplits.size)) - } else { - -1 - } - } - - // Talks to the tracker and receives instruction - private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = { - // Local status of hasSplitsBitVector and splitsInRequestBitVector - val localSplitInfo = getLocalSplitInfo(myId) - - // DO NOT talk to the tracker if all the required splits are already busy - if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) { - return ArrayBuffer[Int]() - } - - val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, - Shuffle.MasterTrackerPort) - val oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - val oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - var selectedSplitIndices = ArrayBuffer[Int]() - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - logInfo("Waited enough for tracker response... Take random response...") - - // sockets will be closed in finally - // TODO: Sometimes timer wont go off - - // TODO: Selecting randomly here. Tracker won't know about it and get an - // asssertion failure when this thread leaves - - selectedSplitIndices = ArrayBuffer(selectRandomSplit) - } - } - - var timeOutTimer = new Timer - // TODO: Which timeout to use? - // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval) - - try { - // Send intention - oosTracker.writeObject(Shuffle.ReducerEntering) - oosTracker.flush() - - // Send what this reducer has - oosTracker.writeObject(localSplitInfo) - oosTracker.flush() - - // Receive reply from the tracker - selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]] - - // Turn the timer OFF - timeOutTimer.cancel() - } catch { - case e: Exception => { - logInfo("getTrackerSelectedSplit had a " + e) - } - } finally { - oisTracker.close() - oosTracker.close() - clientSocketToTracker.close() - } - - return selectedSplitIndices - } - - class ShuffleTracker(outputLocs: Array[SplitInfo]) - extends Thread with Logging { - var threadPool = Shuffle.newDaemonCachedThreadPool - var serverSocket: ServerSocket = null - - // Create trackerStrategy object - val trackerStrategyClass = System.getProperty( - "spark.shuffle.trackerStrategy", - "spark.BalanceConnectionsShuffleTrackerStrategy") - - val trackerStrategy = - Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy] - - // Must initialize here by supplying the outputLocs param - // TODO: This could be avoided by directly passing it to the constructor - trackerStrategy.initialize(outputLocs) - - override def run: Unit = { - serverSocket = new ServerSocket(Shuffle.MasterTrackerPort) - logInfo("ShuffleTracker" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - logInfo("ShuffleTracker had a " + e) - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run: Unit = { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // Receive intention - val reducerIntention = ois.readObject.asInstanceOf[Int] - - if (reducerIntention == Shuffle.ReducerEntering) { - // Receive what the reducer has - val reducerSplitInfo = - ois.readObject.asInstanceOf[SplitInfo] - - // Select splits and update stats if necessary - var selectedSplitIndices = ArrayBuffer[Int]() - trackerStrategy.synchronized { - selectedSplitIndices = trackerStrategy.selectSplit( - reducerSplitInfo) - } - - // Send reply back - oos.writeObject(selectedSplitIndices) - oos.flush() - - // Update internal stats, only if receiver got the reply - trackerStrategy.synchronized { - trackerStrategy.AddReducerToSplit(reducerSplitInfo, - selectedSplitIndices) - } - } - else if (reducerIntention == Shuffle.ReducerLeaving) { - val reducerSplitInfo = - ois.readObject.asInstanceOf[SplitInfo] - - // Receive reception stats: how many blocks the reducer - // read in how much time and from where - val receptionStat = - ois.readObject.asInstanceOf[ReceptionStats] - - // Update stats - trackerStrategy.synchronized { - trackerStrategy.deleteReducerFrom(reducerSplitInfo, - receptionStat) - } - - // Send ACK - oos.writeObject(receptionStat.serverSplitIndex) - oos.flush() - } - else { - throw new SparkException("Undefined reducerIntention") - } - } catch { - // EOFException is expected to happen because receiver can - // break connection due to timeout and pick random instead - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleTracker had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - class ShuffleConsumer(mergeCombiners: (C, C) => C) - extends Thread with Logging { - override def run: Unit = { - // Run until all splits are here - while (receivedData.size > 0) { - var splitIndex = -1 - var recvByteArray: Array[Byte] = null - - try { - var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] - splitIndex = tempPair._1 - recvByteArray = tempPair._2 - } catch { - case e: Exception => { - logInfo("Exception during taking data from receivedData") - } - } - - val inputStream = - new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) - - try{ - while (true) { - val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } - } - } catch { - case e: EOFException => { } - } - inputStream.close() - } - } - } - - class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo, - requestSplit: String, myId: Int) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - private var receptionSucceeded = false - - // Make sure that multiple messages don't go to the tracker - private var alreadySentLeavingNotification = false - - // Keep track of bytes received and time spent - private var numBytesReceived = 0 - private var totalTimeSpent = 0 - - override def run: Unit = { - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - cleanUp() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) - - try { - // Everything will break if BLOCKNUM is not correctly received - // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown - peerSocketToSource = new Socket( - serversplitInfo.hostAddress, serversplitInfo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - var isSource = peerSocketToSource.getInputStream - oisSource = new ObjectInputStream(isSource) - - // Send path information - oosSource.writeObject(requestSplit) - - // TODO: Can be optimized. No need to do it everytime. - // Receive BLOCKNUM - totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int] - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) { - // Set receptionSucceeded to false before trying for each block - receptionSucceeded = false - - // Request specific block - oosSource.writeObject(hasBlocksInSplit(splitIndex)) - - // Good to go. First, receive the length of the requested file - var requestedFileLen = oisSource.readObject.asInstanceOf[Int] - logInfo("Received requestedFileLen = " + requestedFileLen) - - // Create a temp variable to be used in different places - val requestPath = "http://%s:%d/shuffle/%s-%d".format( - serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit, - hasBlocksInSplit(splitIndex)) - - // Receive the file - if (requestedFileLen != -1) { - val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: " + requestPath) - - // Receive data in an Array[Byte] - var recvByteArray = new Array[Byte](requestedFileLen) - var alreadyRead = 0 - var bytesRead = 0 - - while (alreadyRead != requestedFileLen) { - bytesRead = isSource.read(recvByteArray, alreadyRead, - requestedFileLen - alreadyRead) - if (bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - - // Make it available to the consumer - try { - receivedData.put((splitIndex, recvByteArray)) - } catch { - case e: Exception => { - logInfo("Exception during putting data into receivedData") - } - } - - // TODO: Updating stats before consumption is completed - hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 - - // Split has been received only if all the blocks have been received - if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set(splitIndex) - hasSplits += 1 - } - } - - receptionSucceeded = true - - logInfo("END READ: " + requestPath) - val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading " + requestPath + " took " + readTime + " millis.") - - // Update stats - numBytesReceived = numBytesReceived + requestedFileLen - totalTimeSpent = totalTimeSpent + readTime.toInt - } else { - throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit) - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleClient had a " + e) - } - } finally { - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - cleanUp() - } - } - - // Connect to the tracker and update its stats - private def sendLeavingNotification(): Unit = synchronized { - if (!alreadySentLeavingNotification) { - val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, - Shuffle.MasterTrackerPort) - val oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - val oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - try { - // Send intention - oosTracker.writeObject(Shuffle.ReducerLeaving) - oosTracker.flush() - - // Send reducerSplitInfo - oosTracker.writeObject(getLocalSplitInfo(myId)) - oosTracker.flush() - - // Send reception stats - oosTracker.writeObject(ReceptionStats( - numBytesReceived, totalTimeSpent, splitIndex)) - oosTracker.flush() - - // Receive ACK. No need to do anything with that - oisTracker.readObject.asInstanceOf[Int] - - // Now update sentLeavingNotifacation - alreadySentLeavingNotification = true - } catch { - case e: Exception => { - logInfo("sendLeavingNotification had a " + e) - } - } finally { - oisTracker.close() - oosTracker.close() - clientSocketToTracker.close() - } - } - } - - private def cleanUp(): Unit = { - // Update tracker stats first - sendLeavingNotification() - - // Clean up the connections to the mapper - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - - logInfo("Leaving client") - } - } -} - -object TrackedCustomBlockedInMemoryShuffle extends Logging { - // Cache for keeping the splits around - val splitsCache = new HashMap[String, Array[Byte]] - - private var initialized = false - private var nextShuffleId = new AtomicLong(0) - - // Variables initialized by initializeIfNeeded() - private var shuffleDir: File = null - - private var shuffleServer: ShuffleServer = null - private var serverAddress = InetAddress.getLocalHost.getHostAddress - private var serverPort: Int = -1 - - // Random number generator - var ranGen = new Random - - private def initializeIfNeeded() = synchronized { - if (!initialized) { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - // Create and start the shuffleServer - shuffleServer = new ShuffleServer - shuffleServer.setDaemon(true) - shuffleServer.start() - logInfo("ShuffleServer started...") - - initialized = true - } - } - - def getSplitName(shuffleId: Long, inputId: Int, outputId: Int, - blockId: Int): String = { - initializeIfNeeded() - // Adding shuffleDir is unnecessary. Added to keep the parsers working - return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId, - blockId) - } - - def getBlockNumOutputName(shuffleId: Long, inputId: Int, - outputId: Int): String = { - initializeIfNeeded() - return "%s/%d/%d/%d-BLOCKNUM".format(shuffleDir, shuffleId, inputId, - outputId) - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } - - class ShuffleServer - extends Thread with Logging { - var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections) - - var serverSocket: ServerSocket = null - - override def run: Unit = { - serverSocket = new ServerSocket(0) - serverPort = serverSocket.getLocalPort - - logInfo("ShuffleServer started with " + serverSocket) - logInfo("Local URI: http://" + serverAddress + ":" + serverPort) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logInfo("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ShuffleServerThread(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ShuffleServer now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ShuffleServerThread(val clientSocket: Socket) - extends Thread with Logging { - private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] - os.flush() - private val bos = new BufferedOutputStream(os) - bos.flush() - private val oos = new ObjectOutputStream(os) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ShuffleServerThread is running") - - override def run: Unit = { - try { - // Receive basic path information - var requestedSplitBase = ois.readObject.asInstanceOf[String] - - logInfo("requestedSplitBase: " + requestedSplitBase) - - // Read BLOCKNUM and send back the total number of blocks - val blockNumName = "%s/%s-BLOCKNUM".format(shuffleDir, - requestedSplitBase) - - val blockNumIn = new ObjectInputStream(new ByteArrayInputStream( - TrackedCustomBlockedInMemoryShuffle.splitsCache(blockNumName))) - val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int] - blockNumIn.close() - - oos.writeObject(BLOCKNUM) - oos.flush() - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = Shuffle.MaxChatBlocks - - while (keepSending && numBlocksToSend > 0) { - // Receive specific block request - val blockId = ois.readObject.asInstanceOf[Int] - - // Ready to send - var requestedSplit = shuffleDir + "/" + requestedSplitBase + "-" + blockId - - // Send the length of the requestedSplit to let the receiver know that - // transfer is about to start - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - var requestedSplitLen = -1 - - try { - requestedSplitLen = - TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length - } catch { - case e: Exception => { } - } - - oos.writeObject(requestedSplitLen) - oos.flush() - - logInfo("requestedSplitLen = " + requestedSplitLen) - - // Read and send the requested file - if (requestedSplitLen != -1) { - // Send - bos.write(TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit), - 0, requestedSplitLen) - bos.flush() - - // Update loop variables - numBlocksToSend = numBlocksToSend - 1 - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - // TODO: Turning OFF the optimization so that reducers go back to - // tracker get advice - if (curTime - startTime >= Shuffle.MaxChatTime /* && - threadPool.getQueue.size > 0 */) { - keepSending = false - } - } else { - // Close the connection - } - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc - // then close everything up - // Exception can happen if the receiver stops receiving - // EOFException is expected to happen because receiver can break - // connection as soon as it has all the blocks - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleServerThread had a " + e) - } - } finally { - logInfo("ShuffleServerThread is closing streams and sockets") - ois.close() - // TODO: Following can cause "java.net.SocketException: Socket closed" - oos.close() - bos.close() - clientSocket.close() - } - } - } - } -} diff --git a/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala deleted file mode 100644 index a27fa628c6..0000000000 --- a/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala +++ /dev/null @@ -1,930 +0,0 @@ -package spark - -import java.io._ -import java.net._ -import java.util.{BitSet, Random, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -/** - * An implementation of shuffle using local files served through custom server - * where receivers create simultaneous connections to multiple servers by - * setting the 'spark.shuffle.maxRxConnections' config option. - * - * By controlling the 'spark.shuffle.blockSize' config option one can also - * control the largest block size to divide each map output into. Essentially, - * instead of creating one large output file for each reducer, maps create - * multiple smaller files to enable finer level of engagement. - * - * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, - * maxTxConnections >= maxRxConnections * numReducersPerMachine - * - * 'spark.shuffle.TrackerStrategy' decides which strategy to use in the tracker - * - * TODO: Add support for compression when spark.compress is set to true. - */ -@serializable -class TrackedCustomBlockedLocalFileShuffle[K, V, C] -extends Shuffle[K, V, C] with Logging { - @transient var totalSplits = 0 - @transient var hasSplits = 0 - - @transient var totalBlocksInSplit: Array[Int] = null - @transient var hasBlocksInSplit: Array[Int] = null - - @transient var hasSplitsBitVector: BitSet = null - @transient var splitsInRequestBitVector: BitSet = null - - @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null - @transient var combiners: HashMap[K,C] = null - - override def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] = - { - val sc = input.sparkContext - val shuffleId = TrackedCustomBlockedLocalFileShuffle.newShuffleId() - logInfo("Shuffle ID: " + shuffleId) - - val splitRdd = new NumberedSplitRDD(input) - val numInputSplits = splitRdd.splits.size - - // Run a parallel map and collect to write the intermediate data files, - // returning a list of inputSplitId -> serverUri pairs - val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { - val myIndex = pair._1 - val myIterator = pair._2 - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) - for ((k, v) <- myIterator) { - var bucketId = k.hashCode % numOutputSplits - if (bucketId < 0) { // Fix bucket ID if hash code was negative - bucketId += numOutputSplits - } - val bucket = buckets(bucketId) - bucket(k) = bucket.get(k) match { - case Some(c) => mergeValue(c, v) - case None => createCombiner(v) - } - } - - // Keep track of number of blocks for each output split - var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0) - - for (i <- 0 until numOutputSplits) { - var blockNum = 0 - var isDirty = false - var file: File = null - var out: ObjectOutputStream = null - - var writeStartTime: Long = 0 - - buckets(i).foreach(pair => { - // Open a new file if necessary - if (!isDirty) { - file = TrackedCustomBlockedLocalFileShuffle.getOutputFile(shuffleId, - myIndex, i, blockNum) - writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + file) - - out = new ObjectOutputStream(new FileOutputStream(file)) - } - - out.writeObject(pair) - out.flush() - isDirty = true - - // Close the old file if has crossed the blockSize limit - if (file.length > Shuffle.BlockSize) { - out.close() - logInfo("END WRITE: " + file) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - isDirty = false - } - }) - - if (isDirty) { - out.close() - logInfo("END WRITE: " + file) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - } - - // Write the BLOCKNUM file - file = TrackedCustomBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId, - myIndex, i) - out = new ObjectOutputStream(new FileOutputStream(file)) - out.writeObject(blockNum) - out.close() - - // Store number of blocks for this outputSplit - numBlocksPerOutputSplit(i) = blockNum - } - - var retVal = SplitInfo(TrackedCustomBlockedLocalFileShuffle.serverAddress, - TrackedCustomBlockedLocalFileShuffle.serverPort, myIndex) - retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit - - (retVal) - }).collect() - - // Start tracker - var shuffleTracker = new ShuffleTracker(outputLocs) - shuffleTracker.setDaemon(true) - shuffleTracker.start() - logInfo("ShuffleTracker started...") - - // Return an RDD that does each of the merges for a given partition - val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) - return indexes.flatMap((myId: Int) => { - totalSplits = outputLocs.size - hasSplits = 0 - - totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) - hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) - - hasSplitsBitVector = new BitSet(totalSplits) - splitsInRequestBitVector = new BitSet(totalSplits) - - receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] - combiners = new HashMap[K, C] - - var threadPool = Shuffle.newDaemonFixedThreadPool( - Shuffle.MaxRxConnections) - - while (hasSplits < totalSplits) { - // Local status of hasSplitsBitVector and splitsInRequestBitVector - val localSplitInfo = getLocalSplitInfo(myId) - - // DO NOT talk to the tracker if all the required splits are already busy - val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality - - var numThreadsToCreate = - Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) - - threadPool.getActiveCount - - while (hasSplits < totalSplits && numThreadsToCreate > 0) { - // Receive which split to pull from the tracker - logInfo("Talking to tracker...") - val startTime = System.currentTimeMillis - val splitIndices = getTrackerSelectedSplit(myId) - logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime)) - - if (splitIndices.size > 0) { - splitIndices.foreach { splitIndex => - val selectedSplitInfo = outputLocs(splitIndex) - val requestSplit = - "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId) - - threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo, - requestSplit, myId)) - - // splitIndex is in transit. Will be unset in the ShuffleClient - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex) - } - } - } else { - // Tracker replied back with a NO. Sleep for a while. - Thread.sleep(Shuffle.MinKnockInterval) - numThreadsToCreate = 0 - } - - numThreadsToCreate = numThreadsToCreate - splitIndices.size - } - - // Sleep for a while before creating new threads - Thread.sleep(Shuffle.MinKnockInterval) - } - - threadPool.shutdown() - - // Start consumer - // TODO: Consumption is delayed until everything has been received. - // Otherwise it interferes with network performance - var shuffleConsumer = new ShuffleConsumer(mergeCombiners) - shuffleConsumer.setDaemon(true) - shuffleConsumer.start() - logInfo("ShuffleConsumer started...") - - // Don't return until consumption is finished - // while (receivedData.size > 0) { - // Thread.sleep(Shuffle.MinKnockInterval) - // } - - // Wait till shuffleConsumer is done - shuffleConsumer.join - - combiners - }) - } - - private def getLocalSplitInfo(myId: Int): SplitInfo = { - var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress, - SplitInfo.UnusedParam, myId) - - // Store hasSplits - localSplitInfo.hasSplits = hasSplits - - // Store hasSplitsBitVector - hasSplitsBitVector.synchronized { - localSplitInfo.hasSplitsBitVector = - hasSplitsBitVector.clone.asInstanceOf[BitSet] - } - - // Store hasBlocksInSplit to hasBlocksPerInputSplit - hasBlocksInSplit.synchronized { - localSplitInfo.hasBlocksPerInputSplit = - hasBlocksInSplit.clone.asInstanceOf[Array[Int]] - } - - // Include the splitsInRequest as well - splitsInRequestBitVector.synchronized { - localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector) - } - - return localSplitInfo - } - - def selectRandomSplit: Int = { - var requiredSplits = new ArrayBuffer[Int] - - synchronized { - for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { - requiredSplits += i - } - } - } - - if (requiredSplits.size > 0) { - requiredSplits(TrackedCustomBlockedLocalFileShuffle.ranGen.nextInt( - requiredSplits.size)) - } else { - -1 - } - } - - // Talks to the tracker and receives instruction - private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = { - // Local status of hasSplitsBitVector and splitsInRequestBitVector - val localSplitInfo = getLocalSplitInfo(myId) - - // DO NOT talk to the tracker if all the required splits are already busy - if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) { - return ArrayBuffer[Int]() - } - - val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, - Shuffle.MasterTrackerPort) - val oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - val oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - var selectedSplitIndices = ArrayBuffer[Int]() - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - logInfo("Waited enough for tracker response... Take random response...") - - // sockets will be closed in finally - // TODO: Sometimes timer wont go off - - // TODO: Selecting randomly here. Tracker won't know about it and get an - // asssertion failure when this thread leaves - - selectedSplitIndices = ArrayBuffer(selectRandomSplit) - } - } - - var timeOutTimer = new Timer - // TODO: Which timeout to use? - // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval) - - try { - // Send intention - oosTracker.writeObject(Shuffle.ReducerEntering) - oosTracker.flush() - - // Send what this reducer has - oosTracker.writeObject(localSplitInfo) - oosTracker.flush() - - // Receive reply from the tracker - selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]] - - // Turn the timer OFF - timeOutTimer.cancel() - } catch { - case e: Exception => { - logInfo("getTrackerSelectedSplit had a " + e) - } - } finally { - oisTracker.close() - oosTracker.close() - clientSocketToTracker.close() - } - - return selectedSplitIndices - } - - class ShuffleTracker(outputLocs: Array[SplitInfo]) - extends Thread with Logging { - var threadPool = Shuffle.newDaemonCachedThreadPool - var serverSocket: ServerSocket = null - - // Create trackerStrategy object - val trackerStrategyClass = System.getProperty( - "spark.shuffle.trackerStrategy", - "spark.BalanceConnectionsShuffleTrackerStrategy") - - val trackerStrategy = - Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy] - - // Must initialize here by supplying the outputLocs param - // TODO: This could be avoided by directly passing it to the constructor - trackerStrategy.initialize(outputLocs) - - override def run: Unit = { - serverSocket = new ServerSocket(Shuffle.MasterTrackerPort) - logInfo("ShuffleTracker" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - logInfo("ShuffleTracker had a " + e) - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run: Unit = { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // Receive intention - val reducerIntention = ois.readObject.asInstanceOf[Int] - - if (reducerIntention == Shuffle.ReducerEntering) { - // Receive what the reducer has - val reducerSplitInfo = - ois.readObject.asInstanceOf[SplitInfo] - - // Select splits and update stats if necessary - var selectedSplitIndices = ArrayBuffer[Int]() - trackerStrategy.synchronized { - selectedSplitIndices = trackerStrategy.selectSplit( - reducerSplitInfo) - } - - // Send reply back - oos.writeObject(selectedSplitIndices) - oos.flush() - - // Update internal stats, only if receiver got the reply - trackerStrategy.synchronized { - trackerStrategy.AddReducerToSplit(reducerSplitInfo, - selectedSplitIndices) - } - } - else if (reducerIntention == Shuffle.ReducerLeaving) { - val reducerSplitInfo = - ois.readObject.asInstanceOf[SplitInfo] - - // Receive reception stats: how many blocks the reducer - // read in how much time and from where - val receptionStat = - ois.readObject.asInstanceOf[ReceptionStats] - - // Update stats - trackerStrategy.synchronized { - trackerStrategy.deleteReducerFrom(reducerSplitInfo, - receptionStat) - } - - // Send ACK - oos.writeObject(receptionStat.serverSplitIndex) - oos.flush() - } - else { - throw new SparkException("Undefined reducerIntention") - } - } catch { - // EOFException is expected to happen because receiver can - // break connection due to timeout and pick random instead - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleTracker had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - class ShuffleConsumer(mergeCombiners: (C, C) => C) - extends Thread with Logging { - override def run: Unit = { - // Run until all splits are here - while (receivedData.size > 0) { - var splitIndex = -1 - var recvByteArray: Array[Byte] = null - - try { - var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] - splitIndex = tempPair._1 - recvByteArray = tempPair._2 - } catch { - case e: Exception => { - logInfo("Exception during taking data from receivedData") - } - } - - val inputStream = - new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) - - try{ - while (true) { - val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } - } - } catch { - case e: EOFException => { } - } - inputStream.close() - } - } - } - - class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo, - requestSplit: String, myId: Int) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - private var receptionSucceeded = false - - // Make sure that multiple messages don't go to the tracker - private var alreadySentLeavingNotification = false - - // Keep track of bytes received and time spent - private var numBytesReceived = 0 - private var totalTimeSpent = 0 - - override def run: Unit = { - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - cleanUp() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) - - try { - // Everything will break if BLOCKNUM is not correctly received - // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown - peerSocketToSource = new Socket( - serversplitInfo.hostAddress, serversplitInfo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - var isSource = peerSocketToSource.getInputStream - oisSource = new ObjectInputStream(isSource) - - // Send path information - oosSource.writeObject(requestSplit) - - // TODO: Can be optimized. No need to do it everytime. - // Receive BLOCKNUM - totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int] - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) { - // Set receptionSucceeded to false before trying for each block - receptionSucceeded = false - - // Request specific block - oosSource.writeObject(hasBlocksInSplit(splitIndex)) - - // Good to go. First, receive the length of the requested file - var requestedFileLen = oisSource.readObject.asInstanceOf[Int] - logInfo("Received requestedFileLen = " + requestedFileLen) - - // Create a temp variable to be used in different places - val requestPath = "http://%s:%d/shuffle/%s-%d".format( - serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit, - hasBlocksInSplit(splitIndex)) - - // Receive the file - if (requestedFileLen != -1) { - val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: " + requestPath) - - // Receive data in an Array[Byte] - var recvByteArray = new Array[Byte](requestedFileLen) - var alreadyRead = 0 - var bytesRead = 0 - - while (alreadyRead != requestedFileLen) { - bytesRead = isSource.read(recvByteArray, alreadyRead, - requestedFileLen - alreadyRead) - if (bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - - // Make it available to the consumer - try { - receivedData.put((splitIndex, recvByteArray)) - } catch { - case e: Exception => { - logInfo("Exception during putting data into receivedData") - } - } - - // TODO: Updating stats before consumption is completed - hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 - - // Split has been received only if all the blocks have been received - if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set(splitIndex) - hasSplits += 1 - } - } - - receptionSucceeded = true - - logInfo("END READ: " + requestPath) - val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading " + requestPath + " took " + readTime + " millis.") - - // Update stats - numBytesReceived = numBytesReceived + requestedFileLen - totalTimeSpent = totalTimeSpent + readTime.toInt - } else { - throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit) - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleClient had a " + e) - } - } finally { - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - cleanUp() - } - } - - // Connect to the tracker and update its stats - private def sendLeavingNotification(): Unit = synchronized { - if (!alreadySentLeavingNotification) { - val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, - Shuffle.MasterTrackerPort) - val oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - val oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - try { - // Send intention - oosTracker.writeObject(Shuffle.ReducerLeaving) - oosTracker.flush() - - // Send reducerSplitInfo - oosTracker.writeObject(getLocalSplitInfo(myId)) - oosTracker.flush() - - // Send reception stats - oosTracker.writeObject(ReceptionStats( - numBytesReceived, totalTimeSpent, splitIndex)) - oosTracker.flush() - - // Receive ACK. No need to do anything with that - oisTracker.readObject.asInstanceOf[Int] - - // Now update sentLeavingNotifacation - alreadySentLeavingNotification = true - } catch { - case e: Exception => { - logInfo("sendLeavingNotification had a " + e) - } - } finally { - oisTracker.close() - oosTracker.close() - clientSocketToTracker.close() - } - } - } - - private def cleanUp(): Unit = { - // Update tracker stats first - sendLeavingNotification() - - // Clean up the connections to the mapper - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - - logInfo("Leaving client") - } - } -} - -object TrackedCustomBlockedLocalFileShuffle extends Logging { - private var initialized = false - private var nextShuffleId = new AtomicLong(0) - - // Variables initialized by initializeIfNeeded() - private var shuffleDir: File = null - - private var shuffleServer: ShuffleServer = null - private var serverAddress = InetAddress.getLocalHost.getHostAddress - private var serverPort: Int = -1 - - // Random number generator - var ranGen = new Random - - private def initializeIfNeeded() = synchronized { - if (!initialized) { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - // Create and start the shuffleServer - shuffleServer = new ShuffleServer - shuffleServer.setDaemon(true) - shuffleServer.start() - logInfo("ShuffleServer started...") - - initialized = true - } - } - - def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int, - blockId: Int): File = { - initializeIfNeeded() - val dir = new File(shuffleDir, shuffleId + "/" + inputId) - dir.mkdirs() - val file = new File(dir, "%d-%d".format(outputId, blockId)) - return file - } - - def getBlockNumOutputFile(shuffleId: Long, inputId: Int, - outputId: Int): File = { - initializeIfNeeded() - val dir = new File(shuffleDir, shuffleId + "/" + inputId) - dir.mkdirs() - val file = new File(dir, outputId + "-BLOCKNUM") - return file - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } - - class ShuffleServer - extends Thread with Logging { - var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections) - - var serverSocket: ServerSocket = null - - override def run: Unit = { - serverSocket = new ServerSocket(0) - serverPort = serverSocket.getLocalPort - - logInfo("ShuffleServer started with " + serverSocket) - logInfo("Local URI: http://" + serverAddress + ":" + serverPort) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logInfo("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ShuffleServerThread(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ShuffleServer now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ShuffleServerThread(val clientSocket: Socket) - extends Thread with Logging { - private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] - os.flush() - private val bos = new BufferedOutputStream(os) - bos.flush() - private val oos = new ObjectOutputStream(os) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ShuffleServerThread is running") - - override def run: Unit = { - try { - // Receive basic path information - var requestPathBase = ois.readObject.asInstanceOf[String] - - logInfo("requestPathBase: " + requestPathBase) - - // Read BLOCKNUM file and send back the total number of blocks - val blockNumFilePath = "%s/%s-BLOCKNUM".format(shuffleDir, - requestPathBase) - val blockNumIn = - new ObjectInputStream(new FileInputStream(blockNumFilePath)) - val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int] - blockNumIn.close() - - oos.writeObject(BLOCKNUM) - oos.flush() - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = Shuffle.MaxChatBlocks - - while (keepSending && numBlocksToSend > 0) { - // Receive specific block request - val blockId = ois.readObject.asInstanceOf[Int] - - // Ready to send - var requestPath = requestPathBase + "-" + blockId - - // Open the file - var requestedFile: File = null - var requestedFileLen = -1 - try { - requestedFile = new File(shuffleDir + "/" + requestPath) - requestedFileLen = requestedFile.length.toInt - } catch { - case e: Exception => { } - } - - // Send the length of the requestPath to let the receiver know that - // transfer is about to start - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - oos.writeObject(requestedFileLen) - oos.flush() - - logInfo("requestedFileLen = " + requestedFileLen) - - // Read and send the requested file - if (requestedFileLen != -1) { - // Read - var byteArray = new Array[Byte](requestedFileLen) - val bis = - new BufferedInputStream(new FileInputStream(requestedFile)) - - var bytesRead = bis.read(byteArray, 0, byteArray.length) - var alreadyRead = bytesRead - - while (alreadyRead < requestedFileLen) { - bytesRead = bis.read(byteArray, alreadyRead, - (byteArray.length - alreadyRead)) - if(bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - bis.close() - - // Send - bos.write(byteArray, 0, byteArray.length) - bos.flush() - - // Update loop variables - numBlocksToSend = numBlocksToSend - 1 - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - // TODO: Turning OFF the optimization so that reducers go back to - // tracker get advice - if (curTime - startTime >= Shuffle.MaxChatTime /* && - threadPool.getQueue.size > 0 */) { - keepSending = false - } - } else { - // Close the connection - } - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc - // then close everything up - // Exception can happen if the receiver stops receiving - // EOFException is expected to happen because receiver can break - // connection as soon as it has all the blocks - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleServerThread had a " + e) - } - } finally { - logInfo("ShuffleServerThread is closing streams and sockets") - ois.close() - // TODO: Following can cause "java.net.SocketException: Socket closed" - oos.close() - bos.close() - clientSocket.close() - } - } - } - } -} diff --git a/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala deleted file mode 100644 index 4a38a8d7ff..0000000000 --- a/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala +++ /dev/null @@ -1,802 +0,0 @@ -package spark - -import java.io._ -import java.net._ -import java.util.{BitSet, Random, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -/** - * An implementation of shuffle using local files served through custom server - * where receivers create simultaneous connections to multiple servers by - * setting the 'spark.shuffle.maxRxConnections' config option. - * - * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, - * maxTxConnections >= maxRxConnections * numReducersPerMachine - * - * 'spark.shuffle.TrackerStrategy' decides which strategy to use in the tracker - * - * TODO: Add support for compression when spark.compress is set to true. - */ -@serializable -class TrackedCustomParallelLocalFileShuffle[K, V, C] -extends Shuffle[K, V, C] with Logging { - @transient var totalSplits = 0 - @transient var hasSplits = 0 - @transient var hasSplitsBitVector: BitSet = null - @transient var splitsInRequestBitVector: BitSet = null - - @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null - @transient var combiners: HashMap[K,C] = null - - override def compute(input: RDD[(K, V)], - numOutputSplits: Int, - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) - : RDD[(K, C)] = - { - val sc = input.sparkContext - val shuffleId = TrackedCustomParallelLocalFileShuffle.newShuffleId() - logInfo("Shuffle ID: " + shuffleId) - - val splitRdd = new NumberedSplitRDD(input) - val numInputSplits = splitRdd.splits.size - - // Run a parallel map and collect to write the intermediate data files - val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { - val myIndex = pair._1 - val myIterator = pair._2 - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) - for ((k, v) <- myIterator) { - var bucketId = k.hashCode % numOutputSplits - if (bucketId < 0) { // Fix bucket ID if hash code was negative - bucketId += numOutputSplits - } - val bucket = buckets(bucketId) - bucket(k) = bucket.get(k) match { - case Some(c) => mergeValue(c, v) - case None => createCombiner(v) - } - } - - for (i <- 0 until numOutputSplits) { - val file = TrackedCustomParallelLocalFileShuffle.getOutputFile(shuffleId, - myIndex, i) - val writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + file) - val out = new ObjectOutputStream(new FileOutputStream(file)) - buckets(i).foreach(pair => out.writeObject(pair)) - out.close() - logInfo("END WRITE: " + file) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") - } - - (SplitInfo (TrackedCustomParallelLocalFileShuffle.serverAddress, - TrackedCustomParallelLocalFileShuffle.serverPort, myIndex)) - }).collect() - - // Start tracker - var shuffleTracker = new ShuffleTracker(outputLocs) - shuffleTracker.setDaemon(true) - shuffleTracker.start() - logInfo("ShuffleTracker started...") - - // Return an RDD that does each of the merges for a given partition - val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) - return indexes.flatMap((myId: Int) => { - totalSplits = outputLocs.size - hasSplits = 0 - hasSplitsBitVector = new BitSet(totalSplits) - splitsInRequestBitVector = new BitSet(totalSplits) - - receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] - combiners = new HashMap[K, C] - - var threadPool = Shuffle.newDaemonFixedThreadPool( - Shuffle.MaxRxConnections) - - while (hasSplits < totalSplits) { - // Local status of hasSplitsBitVector and splitsInRequestBitVector - val localSplitInfo = getLocalSplitInfo(myId) - - // DO NOT talk to the tracker if all the required splits are already busy - val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality - - var numThreadsToCreate = - Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) - - threadPool.getActiveCount - - while (hasSplits < totalSplits && numThreadsToCreate > 0) { - // Receive which split to pull from the tracker - logInfo("Talking to tracker...") - val startTime = System.currentTimeMillis - val splitIndices = getTrackerSelectedSplit(myId) - logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime)) - - if (splitIndices.size > 0) { - splitIndices.foreach { splitIndex => - val selectedSplitInfo = outputLocs(splitIndex) - val requestSplit = - "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId) - - threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo, - requestSplit, myId)) - - // splitIndex is in transit. Will be unset in the ShuffleClient - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex) - } - } - } else { - // Tracker replied back with a NO. Sleep for a while. - Thread.sleep(Shuffle.MinKnockInterval) - numThreadsToCreate = 0 - } - - numThreadsToCreate = numThreadsToCreate - splitIndices.size - } - - // Sleep for a while before creating new threads - Thread.sleep(Shuffle.MinKnockInterval) - } - - threadPool.shutdown() - - // Start consumer - // TODO: Consumption is delayed until everything has been received. - // Otherwise it interferes with network performance - var shuffleConsumer = new ShuffleConsumer(mergeCombiners) - shuffleConsumer.setDaemon(true) - shuffleConsumer.start() - logInfo("ShuffleConsumer started...") - - // Don't return until consumption is finished - // while (receivedData.size > 0) { - // Thread.sleep(Shuffle.MinKnockInterval) - // } - - // Wait till shuffleConsumer is done - shuffleConsumer.join - - combiners - }) - } - - private def getLocalSplitInfo(myId: Int): SplitInfo = { - var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress, - SplitInfo.UnusedParam, myId) - - localSplitInfo.hasSplits = hasSplits - - hasSplitsBitVector.synchronized { - localSplitInfo.hasSplitsBitVector = - hasSplitsBitVector.clone.asInstanceOf[BitSet] - } - - // Include the splitsInRequest as well - splitsInRequestBitVector.synchronized { - localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector) - } - - return localSplitInfo - } - - // Selects a random split using local information - private def selectRandomSplit: Int = { - var requiredSplits = new ArrayBuffer[Int] - - synchronized { - for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { - requiredSplits += i - } - } - } - - if (requiredSplits.size > 0) { - requiredSplits(TrackedCustomParallelLocalFileShuffle.ranGen.nextInt( - requiredSplits.size)) - } else { - -1 - } - } - - // Talks to the tracker and receives instruction - private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = { - // Local status of hasSplitsBitVector and splitsInRequestBitVector - val localSplitInfo = getLocalSplitInfo(myId) - - // DO NOT talk to the tracker if all the required splits are already busy - if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) { - return ArrayBuffer[Int]() - } - - val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, - Shuffle.MasterTrackerPort) - val oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - val oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - var selectedSplitIndices = ArrayBuffer[Int]() - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - logInfo("Waited enough for tracker response... Take random response...") - - // sockets will be closed in finally - // TODO: Sometimes timer wont go off - - // TODO: Selecting randomly here. Tracker won't know about it and get an - // asssertion failure when this thread leaves - - selectedSplitIndices = ArrayBuffer(selectRandomSplit) - } - } - - var timeOutTimer = new Timer - // TODO: Which timeout to use? - // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval) - - try { - // Send intention - oosTracker.writeObject(Shuffle.ReducerEntering) - oosTracker.flush() - - // Send what this reducer has - oosTracker.writeObject(localSplitInfo) - oosTracker.flush() - - // Receive reply from the tracker - selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]] - - // Turn the timer OFF - timeOutTimer.cancel() - } catch { - case e: Exception => { - logInfo("getTrackerSelectedSplit had a " + e) - } - } finally { - oisTracker.close() - oosTracker.close() - clientSocketToTracker.close() - } - - return selectedSplitIndices - } - - class ShuffleTracker(outputLocs: Array[SplitInfo]) - extends Thread with Logging { - var threadPool = Shuffle.newDaemonCachedThreadPool - var serverSocket: ServerSocket = null - - // Create trackerStrategy object - val trackerStrategyClass = System.getProperty( - "spark.shuffle.trackerStrategy", - "spark.BalanceConnectionsShuffleTrackerStrategy") - - val trackerStrategy = - Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy] - - // Must initialize here by supplying the outputLocs param - // TODO: This could be avoided by directly passing it to the constructor - trackerStrategy.initialize(outputLocs) - - override def run: Unit = { - serverSocket = new ServerSocket(Shuffle.MasterTrackerPort) - logInfo("ShuffleTracker" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - logInfo("ShuffleTracker had a " + e) - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run: Unit = { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // Receive intention - val reducerIntention = ois.readObject.asInstanceOf[Int] - - if (reducerIntention == Shuffle.ReducerEntering) { - // Receive what the reducer has - val reducerSplitInfo = - ois.readObject.asInstanceOf[SplitInfo] - - // Select split and update stats if necessary - var selectedSplitIndices = ArrayBuffer[Int]() - trackerStrategy.synchronized { - selectedSplitIndices = trackerStrategy.selectSplit( - reducerSplitInfo) - } - - // Send reply back - oos.writeObject(selectedSplitIndices) - oos.flush() - - // Update internal stats, only if receiver got the reply - trackerStrategy.synchronized { - trackerStrategy.AddReducerToSplit(reducerSplitInfo, - selectedSplitIndices) - } - } - else if (reducerIntention == Shuffle.ReducerLeaving) { - val reducerSplitInfo = - ois.readObject.asInstanceOf[SplitInfo] - - // Receive reception stats: how many blocks the reducer - // read in how much time and from where - val receptionStat = - ois.readObject.asInstanceOf[ReceptionStats] - - // Update stats - trackerStrategy.synchronized { - trackerStrategy.deleteReducerFrom(reducerSplitInfo, - receptionStat) - } - - // Send ACK - oos.writeObject(receptionStat.serverSplitIndex) - oos.flush() - } - else { - throw new SparkException("Undefined reducerIntention") - } - } catch { - // EOFException is expected to happen because receiver can - // break connection due to timeout and pick random instead - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleTracker had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - class ShuffleConsumer(mergeCombiners: (C, C) => C) - extends Thread with Logging { - override def run: Unit = { - // Run until all splits are here - while (receivedData.size > 0) { - var splitIndex = -1 - var recvByteArray: Array[Byte] = null - - try { - var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] - splitIndex = tempPair._1 - recvByteArray = tempPair._2 - } catch { - case e: Exception => { - logInfo("Exception during taking data from receivedData") - } - } - - val inputStream = - new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) - - try{ - while (true) { - val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } - } - } catch { - case e: EOFException => { } - } - inputStream.close() - } - } - } - - class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo, - requestSplit: String, myId: Int) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - private var receptionSucceeded = false - - // Make sure that multiple messages don't go to the tracker - private var alreadySentLeavingNotification = false - - // Keep track of bytes received and time spent - private var numBytesReceived = 0 - private var totalTimeSpent = 0 - - override def run: Unit = { - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - cleanUp() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) - - // Create a temp variable to be used in different places - val requestPath = "http://%s:%d/shuffle/%s".format( - serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit) - - logInfo("ShuffleClient started... => " + requestPath) - - try { - // Connect to the source - peerSocketToSource = new Socket( - serversplitInfo.hostAddress, serversplitInfo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - var isSource = peerSocketToSource.getInputStream - oisSource = new ObjectInputStream(isSource) - - // Send the request - oosSource.writeObject(requestSplit) - oosSource.flush() - - // Receive the length of the requested file - var requestedFileLen = oisSource.readObject.asInstanceOf[Int] - logInfo("Received requestedFileLen = " + requestedFileLen) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Receive the file - if (requestedFileLen != -1) { - val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: " + requestPath) - - // Receive data in an Array[Byte] - var recvByteArray = new Array[Byte](requestedFileLen) - var alreadyRead = 0 - var bytesRead = 0 - - while (alreadyRead < requestedFileLen) { - bytesRead = isSource.read(recvByteArray, alreadyRead, - requestedFileLen - alreadyRead) - if (bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - - // Make it available to the consumer - try { - receivedData.put((splitIndex, recvByteArray)) - } catch { - case e: Exception => { - logInfo("Exception during putting data into receivedData") - } - } - - // TODO: Updating stats before consumption is completed - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set(splitIndex) - hasSplits += 1 - } - - // We have received splitIndex - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - - receptionSucceeded = true - - logInfo("END READ: " + requestPath) - val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading " + requestPath + " took " + readTime + " millis.") - - // Update stats - numBytesReceived = numBytesReceived + requestedFileLen - totalTimeSpent = totalTimeSpent + readTime.toInt - } else { - throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit) - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo("ShuffleClient had a " + e) - } - } finally { - // If reception failed, unset for future retry - if (!receptionSucceeded) { - splitsInRequestBitVector.synchronized { - splitsInRequestBitVector.set(splitIndex, false) - } - } - cleanUp() - } - } - - // Connect to the tracker and update its stats - private def sendLeavingNotification(): Unit = synchronized { - if (!alreadySentLeavingNotification) { - val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, - Shuffle.MasterTrackerPort) - val oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - val oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - try { - // Send intention - oosTracker.writeObject(Shuffle.ReducerLeaving) - oosTracker.flush() - - // Send reducerSplitInfo - oosTracker.writeObject(getLocalSplitInfo(myId)) - oosTracker.flush() - - // Send reception stats - oosTracker.writeObject(ReceptionStats( - numBytesReceived, totalTimeSpent, splitIndex)) - oosTracker.flush() - - // Receive ACK. No need to do anything with that - oisTracker.readObject.asInstanceOf[Int] - - // Now update sentLeavingNotifacation - alreadySentLeavingNotification = true - } catch { - case e: Exception => { - logInfo("sendLeavingNotification had a " + e) - } - } finally { - oisTracker.close() - oosTracker.close() - clientSocketToTracker.close() - } - } - } - - private def cleanUp(): Unit = { - // Update tracker stats first - sendLeavingNotification() - - // Clean up the connections to the mapper - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - } - } -} - -object TrackedCustomParallelLocalFileShuffle extends Logging { - private var initialized = false - private var nextShuffleId = new AtomicLong(0) - - // Variables initialized by initializeIfNeeded() - private var shuffleDir: File = null - - private var shuffleServer: ShuffleServer = null - private var serverAddress = InetAddress.getLocalHost.getHostAddress - private var serverPort: Int = -1 - - // Random number generator - var ranGen = new Random - - private def initializeIfNeeded() = synchronized { - if (!initialized) { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - // Create and start the shuffleServer - shuffleServer = new ShuffleServer - shuffleServer.setDaemon(true) - shuffleServer.start() - logInfo("ShuffleServer started...") - - initialized = true - } - } - - def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { - initializeIfNeeded() - val dir = new File(shuffleDir, shuffleId + "/" + inputId) - dir.mkdirs() - val file = new File(dir, "" + outputId) - return file - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } - - class ShuffleServer - extends Thread with Logging { - var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections) - - var serverSocket: ServerSocket = null - - override def run: Unit = { - serverSocket = new ServerSocket(0) - serverPort = serverSocket.getLocalPort - - logInfo("ShuffleServer started with " + serverSocket) - logInfo("Local URI: http://" + serverAddress + ":" + serverPort) - - try { - while (true) { - var clientSocket: Socket = null - try { - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logInfo("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ShuffleServerThread(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ShuffleServer now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ShuffleServerThread(val clientSocket: Socket) - extends Thread with Logging { - private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] - os.flush() - private val bos = new BufferedOutputStream(os) - bos.flush() - private val oos = new ObjectOutputStream(os) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ShuffleServerThread is running") - - override def run: Unit = { - try { - // Receive requestPath from the receiver - var requestPath = ois.readObject.asInstanceOf[String] - logInfo("requestPath: " + shuffleDir + "/" + requestPath) - - // Open the file - var requestedFile: File = null - var requestedFileLen = -1 - try { - requestedFile = new File(shuffleDir + "/" + requestPath) - requestedFileLen = requestedFile.length.toInt - } catch { - case e: Exception => { } - } - - // Send the length of the requestPath to let the receiver know that - // transfer is about to start - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - oos.writeObject(requestedFileLen) - oos.flush() - - logInfo("requestedFileLen = " + requestedFileLen) - - // Read and send the requested file - if (requestedFileLen != -1) { - // Read - var byteArray = new Array[Byte](requestedFileLen) - val bis = - new BufferedInputStream(new FileInputStream(requestedFile)) - - var bytesRead = bis.read(byteArray, 0, byteArray.length) - var alreadyRead = bytesRead - - while (alreadyRead < requestedFileLen) { - bytesRead = bis.read(byteArray, alreadyRead, - (byteArray.length - alreadyRead)) - if(bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - bis.close() - - // Send - bos.write(byteArray, 0, byteArray.length) - bos.flush() - } else { - // Close the connection - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc - // then close everything up - // Exception can happen if the receiver stops receiving - case e: Exception => { - logInfo("ShuffleServerThread had a " + e) - } - } finally { - logInfo("ShuffleServerThread is closing streams and sockets") - ois.close() - // TODO: Following can cause "java.net.SocketException: Socket closed" - oos.close() - bos.close() - clientSocket.close() - } - } - } - } -} |