about summary refs log tree commit diff
path: root/core
diff options
context:
space:
mode:
author Mosharaf Chowdhury <mosharaf@cs.berkeley.edu> 2011-04-27 20:53:43 -0700
committer Mosharaf Chowdhury <mosharaf@cs.berkeley.edu> 2011-04-27 20:53:43 -0700
commit 2742de707a7abfd76e3de20e10a0e4a974f12fd5 (patch)
tree 2befc1d283f3b239ab2b033491cd79fcf900010d /core
parent 9d78779257b156bec335af4ab2a66bb3cac30ca6 (diff)
download spark-2742de707a7abfd76e3de20e10a0e4a974f12fd5.tar.gz
spark-2742de707a7abfd76e3de20e10a0e4a974f12fd5.tar.bz2
spark-2742de707a7abfd76e3de20e10a0e4a974f12fd5.zip
Removed some shuffle implementations. Remaining ones all use local files
to write map outputs.
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala 629
-rw-r--r-- core/src/main/scala/spark/CustomParallelFakeShuffle.scala 495
-rw-r--r-- core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala 535
-rw-r--r-- core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala 471
-rw-r--r-- core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala 465
-rw-r--r-- core/src/main/scala/spark/ShuffleTrackerStrategy.scala 470
-rw-r--r-- core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala 926
-rw-r--r-- core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala 930
-rw-r--r-- core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala 802
9 files changed, 0 insertions, 5723 deletions
diff --git a/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala b/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala
deleted file mode 100644
index 5aef43f302..0000000000
--- a/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala
+++ /dev/null
@@ -1,629 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-/**
- * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW).
- *
- * An implementation of shuffle using local memory served through custom server
- * where receivers create simultaneous connections to multiple servers by
- * setting the 'spark.shuffle.maxRxConnections' config option.
- *
- * By controlling the 'spark.shuffle.blockSize' config option one can also
- * control the largest block size to divide each map output into. Essentially,
- * instead of creating one large output file for each reducer, maps create
- * multiple smaller files to enable finer level of engagement.
- *
- * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally,
- * maxTxConnections >= maxRxConnections * numReducersPerMachine
- *
- * TODO: Add support for compression when spark.compress is set to true.
- */
-@serializable
-class CustomBlockedInMemoryShuffle[K, V, C]
-extends Shuffle[K, V, C] with Logging {
- @transient var totalSplits = 0
- @transient var hasSplits = 0
-
- @transient var totalBlocksInSplit: Array[Int] = null
- @transient var hasBlocksInSplit: Array[Int] = null
-
- @transient var hasSplitsBitVector: BitSet = null
- @transient var splitsInRequestBitVector: BitSet = null
-
- @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
- @transient var combiners: HashMap[K,C] = null
-
- override def compute(input: RDD[(K, V)],
- numOutputSplits: Int,
- createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiners: (C, C) => C)
- : RDD[(K, C)] =
- {
- val sc = input.sparkContext
- val shuffleId = CustomBlockedInMemoryShuffle.newShuffleId()
- logInfo("Shuffle ID: " + shuffleId)
-
- val splitRdd = new NumberedSplitRDD(input)
- val numInputSplits = splitRdd.splits.size
-
- // Run a parallel map and collect to write the intermediate data files,
- // returning a list of inputSplitId -> serverUri pairs
- val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
- val myIndex = pair._1
- val myIterator = pair._2
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
- for ((k, v) <- myIterator) {
- var bucketId = k.hashCode % numOutputSplits
- if (bucketId < 0) { // Fix bucket ID if hash code was negative
- bucketId += numOutputSplits
- }
- val bucket = buckets(bucketId)
- bucket(k) = bucket.get(k) match {
- case Some(c) => mergeValue(c, v)
- case None => createCombiner(v)
- }
- }
-
- for (i <- 0 until numOutputSplits) {
- var blockNum = 0
- var isDirty = false
-
- var splitName = ""
- var baos: ByteArrayOutputStream = null
- var oos: ObjectOutputStream = null
-
- var writeStartTime: Long = 0
-
- buckets(i).foreach(pair => {
- // Open a new stream if necessary
- if (!isDirty) {
- splitName = CustomBlockedInMemoryShuffle.getSplitName(shuffleId,
- myIndex, i, blockNum)
-
- baos = new ByteArrayOutputStream
- oos = new ObjectOutputStream(baos)
-
- writeStartTime = System.currentTimeMillis
- logInfo("BEGIN WRITE: " + splitName)
- }
-
- oos.writeObject(pair)
- isDirty = true
-
- // Close the old stream if has crossed the blockSize limit
- if (baos.size > Shuffle.BlockSize) {
- CustomBlockedInMemoryShuffle.splitsCache(splitName) =
- baos.toByteArray
-
- logInfo("END WRITE: " + splitName)
- val writeTime = System.currentTimeMillis - writeStartTime
- logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
-
- blockNum = blockNum + 1
- isDirty = false
- oos.close()
- }
- })
-
- if (isDirty) {
- CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
-
- logInfo("END WRITE: " + splitName)
- val writeTime = System.currentTimeMillis - writeStartTime
- logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
-
- blockNum = blockNum + 1
- oos.close()
- }
-
- // Store BLOCKNUM info
- splitName = CustomBlockedInMemoryShuffle.getBlockNumOutputName(
- shuffleId, myIndex, i)
- baos = new ByteArrayOutputStream
- oos = new ObjectOutputStream(baos)
- oos.writeObject(blockNum)
- CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
-
- // Close streams
- oos.close()
- }
-
- (myIndex, CustomBlockedInMemoryShuffle.serverAddress,
- CustomBlockedInMemoryShuffle.serverPort)
- }).collect()
-
- val splitsByUri = new ArrayBuffer[(String, Int, Int)]
- for ((inputId, serverAddress, serverPort) <- outputLocs) {
- splitsByUri += ((serverAddress, serverPort, inputId))
- }
-
- // TODO: Could broadcast outputLocs
-
- // Return an RDD that does each of the merges for a given partition
- val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
- return indexes.flatMap((myId: Int) => {
- totalSplits = outputLocs.size
- hasSplits = 0
-
- totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
- hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
-
- hasSplitsBitVector = new BitSet(totalSplits)
- splitsInRequestBitVector = new BitSet(totalSplits)
-
- receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
- combiners = new HashMap[K, C]
-
- var threadPool = Shuffle.newDaemonFixedThreadPool(
- Shuffle.MaxRxConnections)
-
- while (hasSplits < totalSplits) {
- var numThreadsToCreate =
- Math.min(totalSplits, Shuffle.MaxRxConnections) -
- threadPool.getActiveCount
-
- while (hasSplits < totalSplits && numThreadsToCreate > 0) {
- // Select a random split to pull
- val splitIndex = selectRandomSplit
-
- if (splitIndex != -1) {
- val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
-
- threadPool.execute(new ShuffleClient(serverAddress, serverPort,
- shuffleId.toInt, inputId, myId, splitIndex))
-
- // splitIndex is in transit. Will be unset in the ShuffleClient
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex)
- }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before creating new threads
- Thread.sleep(Shuffle.MinKnockInterval)
- }
-
- threadPool.shutdown()
-
- // Start consumer
- // TODO: Consumption is delayed until everything has been received.
- // Otherwise it interferes with network performance
- var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
- shuffleConsumer.setDaemon(true)
- shuffleConsumer.start()
- logInfo("ShuffleConsumer started...")
-
- // Don't return until consumption is finished
- // while (receivedData.size > 0) {
- // Thread.sleep(Shuffle.MinKnockInterval)
- // }
-
- // Wait till shuffleConsumer is done
- shuffleConsumer.join
-
- combiners
- })
- }
-
- def selectRandomSplit: Int = {
- var requiredSplits = new ArrayBuffer[Int]
-
- synchronized {
- for (i <- 0 until totalSplits) {
- if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
- requiredSplits += i
- }
- }
- }
-
- if (requiredSplits.size > 0) {
- requiredSplits(CustomBlockedInMemoryShuffle.ranGen.nextInt(
- requiredSplits.size))
- } else {
- -1
- }
- }
-
- class ShuffleConsumer(mergeCombiners: (C, C) => C)
- extends Thread with Logging {
- override def run: Unit = {
- // Run until all splits are here
- while (receivedData.size > 0) {
- var splitIndex = -1
- var recvByteArray: Array[Byte] = null
-
- try {
- var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
- splitIndex = tempPair._1
- recvByteArray = tempPair._2
- } catch {
- case e: Exception => {
- logInfo("Exception during taking data from receivedData")
- }
- }
-
- val inputStream =
- new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
-
- try{
- while (true) {
- val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
- combiners(k) = combiners.get(k) match {
- case Some(oldC) => mergeCombiners(oldC, c)
- case None => c
- }
- }
- } catch {
- case e: EOFException => { }
- }
- inputStream.close()
- }
- }
- }
-
- class ShuffleClient(hostAddress: String, listenPort: Int, shuffleId: Int,
- inputId: Int, myId: Int, splitIndex: Int)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- private var receptionSucceeded = false
-
- override def run: Unit = {
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run: Unit = {
- cleanUpConnections()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
-
- try {
- // Everything will break if BLOCKNUM is not correctly received
- // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
- peerSocketToSource = new Socket(hostAddress, listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- var isSource = peerSocketToSource.getInputStream
- oisSource = new ObjectInputStream(isSource)
-
- // Send path information
- oosSource.writeObject((shuffleId, inputId, myId))
-
- // TODO: Can be optimized. No need to do it everytime.
- // Receive BLOCKNUM
- totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) {
- // Set receptionSucceeded to false before trying for each block
- receptionSucceeded = false
-
- // Request specific block
- oosSource.writeObject(hasBlocksInSplit(splitIndex))
-
- // Good to go. First, receive the length of the requested file
- var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
- logInfo("Received requestedFileLen = " + requestedFileLen)
-
- val requestSplit = "%d/%d/%d-%d".format(shuffleId, inputId, myId,
- hasBlocksInSplit(splitIndex))
-
- // Receive the file
- if (requestedFileLen != -1) {
- val readStartTime = System.currentTimeMillis
- logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
-
- // Receive data in an Array[Byte]
- var recvByteArray = new Array[Byte](requestedFileLen)
- var alreadyRead = 0
- var bytesRead = 0
-
- while (alreadyRead != requestedFileLen) {
- bytesRead = isSource.read(recvByteArray, alreadyRead,
- requestedFileLen - alreadyRead)
- if (bytesRead > 0) {
- alreadyRead = alreadyRead + bytesRead
- }
- }
-
- // Make it available to the consumer
- try {
- receivedData.put((splitIndex, recvByteArray))
- } catch {
- case e: Exception => {
- logInfo("Exception during putting data into receivedData")
- }
- }
-
- // TODO: Updating stats before consumption is completed
- hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
-
- // Split has been received only if all the blocks have been received
- if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
- hasSplitsBitVector.synchronized {
- hasSplitsBitVector.set(splitIndex)
- hasSplits += 1
- }
- }
-
- // Consistent state in accounting variables
- receptionSucceeded = true
-
- logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
- val readTime = System.currentTimeMillis - readStartTime
- logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
- } else {
- throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
- }
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleClient had a " + e)
- }
- } finally {
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
- cleanUpConnections()
- }
- }
-
- private def cleanUpConnections(): Unit = {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
- }
- }
-}
-
-object CustomBlockedInMemoryShuffle extends Logging {
- // Cache for keeping the splits around
- val splitsCache = new HashMap[String, Array[Byte]]
-
- private var initialized = false
- private var nextShuffleId = new AtomicLong(0)
-
- // Variables initialized by initializeIfNeeded()
- private var shuffleDir: File = null
-
- private var shuffleServer: ShuffleServer = null
- private var serverAddress = InetAddress.getLocalHost.getHostAddress
- private var serverPort: Int = -1
-
- // Random number generator
- var ranGen = new Random
-
- private def initializeIfNeeded() = synchronized {
- if (!initialized) {
- // TODO: localDir should be created by some mechanism common to Spark
- // so that it can be shared among shuffle, broadcast, etc
- val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
- var tries = 0
- var foundLocalDir = false
- var localDir: File = null
- var localDirUuid: UUID = null
- while (!foundLocalDir && tries < 10) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID
- localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed 10 attempts to create local dir in " + localDirRoot)
- System.exit(1)
- }
- shuffleDir = new File(localDir, "shuffle")
- shuffleDir.mkdirs()
- logInfo("Shuffle dir: " + shuffleDir)
-
- // Create and start the shuffleServer
- shuffleServer = new ShuffleServer
- shuffleServer.setDaemon(true)
- shuffleServer.start()
- logInfo("ShuffleServer started...")
-
- initialized = true
- }
- }
-
- def getSplitName(shuffleId: Long, inputId: Int, outputId: Int,
- blockId: Int): String = {
- initializeIfNeeded()
- // Adding shuffleDir is unnecessary. Added to keep the parsers working
- return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId,
- blockId)
- }
-
- def getBlockNumOutputName(shuffleId: Long, inputId: Int,
- outputId: Int): String = {
- initializeIfNeeded()
- return "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir, shuffleId, inputId,
- outputId)
- }
-
- def newShuffleId(): Long = {
- nextShuffleId.getAndIncrement()
- }
-
- class ShuffleServer
- extends Thread with Logging {
- var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
-
- var serverSocket: ServerSocket = null
-
- override def run: Unit = {
- serverSocket = new ServerSocket(0)
- serverPort = serverSocket.getLocalPort
-
- logInfo("ShuffleServer started with " + serverSocket)
- logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logInfo("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ShuffleServerThread(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ShuffleServer now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ShuffleServerThread(val clientSocket: Socket)
- extends Thread with Logging {
- private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
- os.flush()
- private val bos = new BufferedOutputStream(os)
- bos.flush()
- private val oos = new ObjectOutputStream(os)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ShuffleServerThread is running")
-
- override def run: Unit = {
- try {
- // Receive basic path information
- val (shuffleId, myIndex, outputId) =
- ois.readObject.asInstanceOf[(Int, Int, Int)]
-
- var requestedSplitBase = "%s/%d/%d/%d".format(
- shuffleDir, shuffleId, myIndex, outputId)
- logInfo("requestedSplitBase: " + requestedSplitBase)
-
- // Read BLOCKNUM and send back the total number of blocks
- val blockNumName = "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir,
- shuffleId, myIndex, outputId)
-
- val blockNumIn = new ObjectInputStream(new ByteArrayInputStream(
- CustomBlockedInMemoryShuffle.splitsCache(blockNumName)))
- val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
- blockNumIn.close()
-
- oos.writeObject(BLOCKNUM)
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = Shuffle.MaxChatBlocks
-
- while (keepSending && numBlocksToSend > 0) {
- // Receive specific block request
- val blockId = ois.readObject.asInstanceOf[Int]
-
- // Ready to send
- var requestedSplit = requestedSplitBase + "-" + blockId
-
- // Send the length of the requestedSplit to let the receiver know that
- // transfer is about to start
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- var requestedSplitLen = -1
-
- try {
- requestedSplitLen =
- CustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length
- } catch {
- case e: Exception => { }
- }
-
- oos.writeObject(requestedSplitLen)
- oos.flush()
-
- logInfo("requestedSplitLen = " + requestedSplitLen)
-
- // Read and send the requested file
- if (requestedSplitLen != -1) {
- // Send
- bos.write(CustomBlockedInMemoryShuffle.splitsCache(requestedSplit),
- 0, requestedSplitLen)
- bos.flush()
-
- // Update loop variables
- numBlocksToSend = numBlocksToSend - 1
-
- curTime = System.currentTimeMillis
- // Revoke sending only if there is anyone waiting in the queue
- if (curTime - startTime >= Shuffle.MaxChatTime &&
- threadPool.getQueue.size > 0) {
- keepSending = false
- }
- } else {
- // Close the connection
- }
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc
- // then close everything up
- // Exception can happen if the receiver stops receiving
- // EOFException is expected to happen because receiver can break
- // connection as soon as it has all the blocks
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleServerThread had a " + e)
- }
- } finally {
- logInfo("ShuffleServerThread is closing streams and sockets")
- ois.close()
- // TODO: Following can cause "java.net.SocketException: Socket closed"
- oos.close()
- bos.close()
- clientSocket.close()
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/spark/CustomParallelFakeShuffle.scala b/core/src/main/scala/spark/CustomParallelFakeShuffle.scala
deleted file mode 100644
index 889c5111b6..0000000000
--- a/core/src/main/scala/spark/CustomParallelFakeShuffle.scala
+++ /dev/null
@@ -1,495 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-/**
- * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW).
- *
- * An implementation of shuffle using fake data served through custom server
- * where receivers create simultaneous connections to multiple servers by
- * setting the 'spark.shuffle.maxRxConnections' config option.
- *
- * TODO: Add support for compression when spark.compress is set to true.
- */
-@serializable
-class CustomParallelFakeShuffle[K, V, C]
-extends Shuffle[K, V, C] with Logging {
- @transient var totalSplits = 0
- @transient var hasSplits = 0
- @transient var hasSplitsBitVector: BitSet = null
- @transient var splitsInRequestBitVector: BitSet = null
-
- @transient var combiners: HashMap[K,C] = null
-
- override def compute(input: RDD[(K, V)],
- numOutputSplits: Int,
- createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiners: (C, C) => C)
- : RDD[(K, C)] =
- {
- val sc = input.sparkContext
- val shuffleId = CustomParallelFakeShuffle.newShuffleId()
- logInfo("Shuffle ID: " + shuffleId)
-
- val splitRdd = new NumberedSplitRDD(input)
- val numInputSplits = splitRdd.splits.size
-
- // Run a parallel map and collect to write the intermediate data files,
- // returning a list of inputSplitId -> serverUri pairs
- val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
- val myIndex = pair._1
- val myIterator = pair._2
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
- for ((k, v) <- myIterator) {
- var bucketId = k.hashCode % numOutputSplits
- if (bucketId < 0) { // Fix bucket ID if hash code was negative
- bucketId += numOutputSplits
- }
- val bucket = buckets(bucketId)
- bucket(k) = bucket.get(k) match {
- case Some(c) => mergeValue(c, v)
- case None => createCombiner(v)
- }
- }
-
- for (i <- 0 until numOutputSplits) {
- val splitName =
- CustomParallelFakeShuffle.getSplitName(shuffleId, myIndex, i)
-
- val writeStartTime = System.currentTimeMillis
- logInfo("BEGIN WRITE: " + splitName)
-
- // Write buckets(i) to a byte array & put in splitsCache instead of file
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream(baos)
- buckets(i).foreach(pair => oos.writeObject(pair))
- oos.close
- baos.close
-
- // Store the length only
- CustomParallelFakeShuffle.splitsCache(splitName) = baos.toByteArray.length
- val splitLen = baos.toByteArray.length
-
- logInfo("END WRITE: " + splitName)
- val writeTime = System.currentTimeMillis - writeStartTime
- logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.")
- }
-
- (myIndex, CustomParallelFakeShuffle.serverAddress,
- CustomParallelFakeShuffle.serverPort)
- }).collect()
-
- val splitsByUri = new ArrayBuffer[(String, Int, Int)]
- for ((inputId, serverAddress, serverPort) <- outputLocs) {
- splitsByUri += ((serverAddress, serverPort, inputId))
- }
-
- // TODO: Could broadcast splitsByUri
-
- // Return an RDD that does each of the merges for a given partition
- val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
- return indexes.flatMap((myId: Int) => {
- totalSplits = splitsByUri.size
- hasSplits = 0
- hasSplitsBitVector = new BitSet(totalSplits)
- splitsInRequestBitVector = new BitSet(totalSplits)
-
- combiners = new HashMap[K, C]
-
- var threadPool = Shuffle.newDaemonFixedThreadPool(
- Shuffle.MaxRxConnections)
-
- var totalThreadsCreated = 0
-
- while (hasSplits < totalSplits) {
- var numThreadsToCreate = Math.min(totalSplits,
- Shuffle.MaxRxConnections) -
- threadPool.getActiveCount
-
- while (hasSplits < totalSplits && numThreadsToCreate > 0) {
- // Select a random split to pull
- val splitIndex = selectRandomSplit
-
- if (splitIndex != -1) {
- val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
- val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId)
-
- threadPool.execute(new ShuffleClient(splitIndex, serverAddress,
- serverPort, requestSplit))
- totalThreadsCreated += 1
- logInfo("totalThreadsCreated = %d / threadPool.getActiveCount = %d".format(totalThreadsCreated, threadPool.getActiveCount.toInt))
-
- // splitIndex is in transit. Will be unset in the ShuffleClient
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex)
- }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before creating new threads
- Thread.sleep(Shuffle.MinKnockInterval)
- }
-
- threadPool.shutdown()
-
- combiners
- })
- }
-
- def selectRandomSplit: Int = {
- var requiredSplits = new ArrayBuffer[Int]
-
- synchronized {
- for (i <- 0 until totalSplits) {
- if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
- requiredSplits += i
- }
- }
- }
-
- if (requiredSplits.size > 0) {
- requiredSplits(CustomParallelFakeShuffle.ranGen.nextInt(
- requiredSplits.size))
- } else {
- -1
- }
- }
-
- class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int,
- requestSplit: String)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- private var receptionSucceeded = false
-
- override def run: Unit = {
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run: Unit = {
- cleanUpConnections()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
-
- logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit))
-
- try {
- // Connect to the source
- peerSocketToSource = new Socket(hostAddress, listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- var isSource = peerSocketToSource.getInputStream
- oisSource = new ObjectInputStream(isSource)
-
- // Send the request
- oosSource.writeObject(requestSplit)
-
- // Receive the length of the requested file
- var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
- logInfo("Received requestedFileLen = " + requestedFileLen)
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- // Receive the file
- if (requestedFileLen != -1) {
- val readStartTime = System.currentTimeMillis
- logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
-
-// var recvByteArray = new Array[Byte](requestedFileLen)
-// var alreadyRead = 0
-// var bytesRead = 0
-//
-// while (alreadyRead != requestedFileLen) {
-// bytesRead = isSource.read(recvByteArray, alreadyRead,
-// requestedFileLen - alreadyRead)
-// if (bytesRead > 0) {
-// alreadyRead = alreadyRead + bytesRead
-// }
-// }
-
- // Receive data in an Array[Byte]
- var recvByteArray = new Array[Byte](65536)
- var alreadyRead = 0
- var bytesRead = 0
-
- while (alreadyRead != requestedFileLen) {
- bytesRead = isSource.read(recvByteArray)
- if (bytesRead > 0) {
- alreadyRead = alreadyRead + bytesRead
- }
- }
-
- // TODO: Updating stats before consumption is completed
- hasSplitsBitVector.synchronized {
- hasSplitsBitVector.set(splitIndex)
- hasSplits += 1
- }
-
- // We have received splitIndex
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
-
- receptionSucceeded = true
-
- logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
- val readTime = System.currentTimeMillis - readStartTime
- logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
- } else {
- throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleClient had a " + e)
- }
- } finally {
- // If reception failed, unset for future retry
- if (!receptionSucceeded) {
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
- }
- cleanUpConnections()
- }
- }
-
- private def cleanUpConnections(): Unit = {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
- }
- }
-}
-
-object CustomParallelFakeShuffle extends Logging {
- // Cache for keeping the splits around
- val splitsCache = new HashMap[String, Int]
-
- private var initialized = false
- private var nextShuffleId = new AtomicLong(0)
-
- // Variables initialized by initializeIfNeeded()
- private var shuffleDir: File = null
-
- private var shuffleServer: ShuffleServer = null
- private var serverAddress = InetAddress.getLocalHost.getHostAddress
- private var serverPort: Int = -1
-
- // Random number generator
- var ranGen = new Random
-
- private def initializeIfNeeded() = synchronized {
- if (!initialized) {
- // TODO: localDir should be created by some mechanism common to Spark
- // so that it can be shared among shuffle, broadcast, etc
- val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
- var tries = 0
- var foundLocalDir = false
- var localDir: File = null
- var localDirUuid: UUID = null
-
- while (!foundLocalDir && tries < 10) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID
- localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed 10 attempts to create local dir in " + localDirRoot)
- System.exit(1)
- }
- shuffleDir = new File(localDir, "shuffle")
- shuffleDir.mkdirs()
- logInfo("Shuffle dir: " + shuffleDir)
-
- // Create and start the shuffleServer
- shuffleServer = new ShuffleServer
- shuffleServer.setDaemon(true)
- shuffleServer.start()
- logInfo("ShuffleServer started...")
-
- initialized = true
- }
- }
-
- def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = {
- initializeIfNeeded()
- // Adding shuffleDir is unnecessary. Added to keep the parsers working
- return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId)
- }
-
- def newShuffleId(): Long = {
- nextShuffleId.getAndIncrement()
- }
-
- // Returns a standard ThreadFactory except all threads are daemons
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread(r)
- t.setDaemon(true)
- return t
- }
- }
- }
-
- // Wrapper over newFixedThreadPool
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool =
- Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
-
- class ShuffleServer
- extends Thread with Logging {
- var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
-
- var serverSocket: ServerSocket = null
-
- override def run: Unit = {
- serverSocket = new ServerSocket(0)
- serverPort = serverSocket.getLocalPort
-
- logInfo("ShuffleServer started with " + serverSocket)
- logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logInfo("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ShuffleServerThread(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ShuffleServer now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ShuffleServerThread(val clientSocket: Socket)
- extends Thread with Logging {
- private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
- os.flush()
- private val bos = new BufferedOutputStream(os)
- bos.flush()
- private val oos = new ObjectOutputStream(os)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ShuffleServerThread is running")
-
- override def run: Unit = {
- try {
- // Receive requestedSplit from the receiver
- // Adding shuffleDir is unnecessary. Added to keep the parsers working
- var requestedSplit =
- shuffleDir + "/" + ois.readObject.asInstanceOf[String]
- logInfo("requestedSplit: " + requestedSplit)
-
- // Send the length of the requestedSplit to let the receiver know that
- // transfer is about to start
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- var requestedSplitLen = -1
-
- try {
- requestedSplitLen =
- CustomParallelFakeShuffle.splitsCache(requestedSplit)
- } catch {
- case e: Exception => { }
- }
-
- oos.writeObject(requestedSplitLen)
- oos.flush()
-
- logInfo("requestedSplitLen = " + requestedSplitLen)
-
- // Read and send the requested split
- if (requestedSplitLen != -1) {
- // Send fake data
-// var byteArray = new Array[Byte](requestedSplitLen)
-// bos.write(byteArray, 0, byteArray.length)
-// bos.flush()
-
- val buf = new Array[Byte](65536)
- var bytesSent = 0
- while (bytesSent < requestedSplitLen) {
- val bytesToSend = Math.min(requestedSplitLen - bytesSent, buf.length)
- bos.write(buf, 0, bytesToSend)
- bos.flush()
- bytesSent += bytesToSend
- }
- } else {
- // Close the connection
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc
- // then close everything up
- // Exception can happen if the receiver stops receiving
- case e: Exception => {
- logInfo("ShuffleServerThread had a " + e)
- }
- } finally {
- logInfo("ShuffleServerThread is closing streams and sockets")
- ois.close()
- // TODO: Following can cause "java.net.SocketException: Socket closed"
- oos.close()
- bos.close()
- clientSocket.close()
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala b/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala
deleted file mode 100644
index 48f3685a1a..0000000000
--- a/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala
+++ /dev/null
@@ -1,535 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-/**
- * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW).
- *
- * An implementation of shuffle using local memory served through custom server
- * where receivers create simultaneous connections to multiple servers by
- * setting the 'spark.shuffle.maxRxConnections' config option.
- *
- * TODO: Add support for compression when spark.compress is set to true.
- */
-@serializable
-class CustomParallelInMemoryShuffle[K, V, C]
-extends Shuffle[K, V, C] with Logging {
- @transient var totalSplits = 0
- @transient var hasSplits = 0
- @transient var hasSplitsBitVector: BitSet = null
- @transient var splitsInRequestBitVector: BitSet = null
-
- @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
- @transient var combiners: HashMap[K,C] = null
-
- override def compute(input: RDD[(K, V)],
- numOutputSplits: Int,
- createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiners: (C, C) => C)
- : RDD[(K, C)] =
- {
- val sc = input.sparkContext
- val shuffleId = CustomParallelInMemoryShuffle.newShuffleId()
- logInfo("Shuffle ID: " + shuffleId)
-
- val splitRdd = new NumberedSplitRDD(input)
- val numInputSplits = splitRdd.splits.size
-
- // Run a parallel map and collect to write the intermediate data files,
- // returning a list of inputSplitId -> serverUri pairs
- val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
- val myIndex = pair._1
- val myIterator = pair._2
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
- for ((k, v) <- myIterator) {
- var bucketId = k.hashCode % numOutputSplits
- if (bucketId < 0) { // Fix bucket ID if hash code was negative
- bucketId += numOutputSplits
- }
- val bucket = buckets(bucketId)
- bucket(k) = bucket.get(k) match {
- case Some(c) => mergeValue(c, v)
- case None => createCombiner(v)
- }
- }
-
- for (i <- 0 until numOutputSplits) {
- val splitName =
- CustomParallelInMemoryShuffle.getSplitName(shuffleId, myIndex, i)
-
- val writeStartTime = System.currentTimeMillis
- logInfo("BEGIN WRITE: " + splitName)
-
- // Write buckets(i) to a byte array & put in splitsCache instead of file
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream(baos)
- buckets(i).foreach(pair => oos.writeObject(pair))
- oos.close
- baos.close
-
- CustomParallelInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
- val splitLen =
- CustomParallelInMemoryShuffle.splitsCache(splitName).length
-
- logInfo("END WRITE: " + splitName)
- val writeTime = System.currentTimeMillis - writeStartTime
- logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.")
- }
-
- (myIndex, CustomParallelInMemoryShuffle.serverAddress,
- CustomParallelInMemoryShuffle.serverPort)
- }).collect()
-
- val splitsByUri = new ArrayBuffer[(String, Int, Int)]
- for ((inputId, serverAddress, serverPort) <- outputLocs) {
- splitsByUri += ((serverAddress, serverPort, inputId))
- }
-
- // TODO: Could broadcast splitsByUri
-
- // Return an RDD that does each of the merges for a given partition
- val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
- return indexes.flatMap((myId: Int) => {
- totalSplits = splitsByUri.size
- hasSplits = 0
- hasSplitsBitVector = new BitSet(totalSplits)
- splitsInRequestBitVector = new BitSet(totalSplits)
-
- receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
- combiners = new HashMap[K, C]
-
- var threadPool = Shuffle.newDaemonFixedThreadPool(
- Shuffle.MaxRxConnections)
-
- while (hasSplits < totalSplits) {
- var numThreadsToCreate = Math.min(totalSplits,
- Shuffle.MaxRxConnections) -
- threadPool.getActiveCount
-
- while (hasSplits < totalSplits && numThreadsToCreate > 0) {
- // Select a random split to pull
- val splitIndex = selectRandomSplit
-
- if (splitIndex != -1) {
- val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
- val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId)
-
- threadPool.execute(new ShuffleClient(splitIndex, serverAddress,
- serverPort, requestSplit))
-
- // splitIndex is in transit. Will be unset in the ShuffleClient
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex)
- }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before creating new threads
- Thread.sleep(Shuffle.MinKnockInterval)
- }
-
- threadPool.shutdown()
-
- // Start consumer
- // TODO: Consumption is delayed until everything has been received.
- // Otherwise it interferes with network performance
- var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
- shuffleConsumer.setDaemon(true)
- shuffleConsumer.start()
- logInfo("ShuffleConsumer started...")
-
- // Don't return until consumption is finished
- // while (receivedData.size > 0) {
- // Thread.sleep(Shuffle.MinKnockInterval)
- // }
-
- // Wait till shuffleConsumer is done
- shuffleConsumer.join
-
- combiners
- })
- }
-
- def selectRandomSplit: Int = {
- var requiredSplits = new ArrayBuffer[Int]
-
- synchronized {
- for (i <- 0 until totalSplits) {
- if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
- requiredSplits += i
- }
- }
- }
-
- if (requiredSplits.size > 0) {
- requiredSplits(CustomParallelInMemoryShuffle.ranGen.nextInt(
- requiredSplits.size))
- } else {
- -1
- }
- }
-
- class ShuffleConsumer(mergeCombiners: (C, C) => C)
- extends Thread with Logging {
- override def run: Unit = {
- // Run until all splits are here
- while (receivedData.size > 0) {
- var splitIndex = -1
- var recvByteArray: Array[Byte] = null
-
- try {
- var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
- splitIndex = tempPair._1
- recvByteArray = tempPair._2
- } catch {
- case e: Exception => {
- logInfo("Exception during taking data from receivedData")
- }
- }
-
- val inputStream =
- new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
-
- try{
- while (true) {
- val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
- combiners(k) = combiners.get(k) match {
- case Some(oldC) => mergeCombiners(oldC, c)
- case None => c
- }
- }
- } catch {
- case e: EOFException => { }
- }
- inputStream.close()
- }
- }
- }
-
- class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int,
- requestSplit: String)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- private var receptionSucceeded = false
-
- override def run: Unit = {
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run: Unit = {
- cleanUpConnections()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
-
- logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit))
-
- try {
- // Connect to the source
- peerSocketToSource = new Socket(hostAddress, listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- var isSource = peerSocketToSource.getInputStream
- oisSource = new ObjectInputStream(isSource)
-
- // Send the request
- oosSource.writeObject(requestSplit)
-
- // Receive the length of the requested file
- var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
- logInfo("Received requestedFileLen = " + requestedFileLen)
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- // Receive the file
- if (requestedFileLen != -1) {
- val readStartTime = System.currentTimeMillis
- logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
-
- // Receive data in an Array[Byte]
- var recvByteArray = new Array[Byte](requestedFileLen)
- var alreadyRead = 0
- var bytesRead = 0
-
- while (alreadyRead != requestedFileLen) {
- bytesRead = isSource.read(recvByteArray, alreadyRead,
- requestedFileLen - alreadyRead)
- if (bytesRead > 0) {
- alreadyRead = alreadyRead + bytesRead
- }
- }
-
- // Make it available to the consumer
- try {
- receivedData.put((splitIndex, recvByteArray))
- } catch {
- case e: Exception => {
- logInfo("Exception during putting data into receivedData")
- }
- }
-
- // TODO: Updating stats before consumption is completed
- hasSplitsBitVector.synchronized {
- hasSplitsBitVector.set(splitIndex)
- hasSplits += 1
- }
-
- // We have received splitIndex
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
-
- receptionSucceeded = true
-
- logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
- val readTime = System.currentTimeMillis - readStartTime
- logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
- } else {
- throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleClient had a " + e)
- }
- } finally {
- // If reception failed, unset for future retry
- if (!receptionSucceeded) {
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
- }
- cleanUpConnections()
- }
- }
-
- private def cleanUpConnections(): Unit = {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
- }
- }
-}
-
-object CustomParallelInMemoryShuffle extends Logging {
- // Cache for keeping the splits around
- val splitsCache = new HashMap[String, Array[Byte]]
-
- private var initialized = false
- private var nextShuffleId = new AtomicLong(0)
-
- // Variables initialized by initializeIfNeeded()
- private var shuffleDir: File = null
-
- private var shuffleServer: ShuffleServer = null
- private var serverAddress = InetAddress.getLocalHost.getHostAddress
- private var serverPort: Int = -1
-
- // Random number generator
- var ranGen = new Random
-
- private def initializeIfNeeded() = synchronized {
- if (!initialized) {
- // TODO: localDir should be created by some mechanism common to Spark
- // so that it can be shared among shuffle, broadcast, etc
- val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
- var tries = 0
- var foundLocalDir = false
- var localDir: File = null
- var localDirUuid: UUID = null
-
- while (!foundLocalDir && tries < 10) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID
- localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed 10 attempts to create local dir in " + localDirRoot)
- System.exit(1)
- }
- shuffleDir = new File(localDir, "shuffle")
- shuffleDir.mkdirs()
- logInfo("Shuffle dir: " + shuffleDir)
-
- // Create and start the shuffleServer
- shuffleServer = new ShuffleServer
- shuffleServer.setDaemon(true)
- shuffleServer.start()
- logInfo("ShuffleServer started...")
-
- initialized = true
- }
- }
-
- def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = {
- initializeIfNeeded()
- // Adding shuffleDir is unnecessary. Added to keep the parsers working
- return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId)
- }
-
- def newShuffleId(): Long = {
- nextShuffleId.getAndIncrement()
- }
-
- // Returns a standard ThreadFactory except all threads are daemons
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread(r)
- t.setDaemon(true)
- return t
- }
- }
- }
-
- // Wrapper over newFixedThreadPool
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool =
- Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
-
- class ShuffleServer
- extends Thread with Logging {
- var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
-
- var serverSocket: ServerSocket = null
-
- override def run: Unit = {
- serverSocket = new ServerSocket(0)
- serverPort = serverSocket.getLocalPort
-
- logInfo("ShuffleServer started with " + serverSocket)
- logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logInfo("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ShuffleServerThread(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ShuffleServer now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ShuffleServerThread(val clientSocket: Socket)
- extends Thread with Logging {
- private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
- os.flush()
- private val bos = new BufferedOutputStream(os)
- bos.flush()
- private val oos = new ObjectOutputStream(os)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ShuffleServerThread is running")
-
- override def run: Unit = {
- try {
- // Receive requestedSplit from the receiver
- // Adding shuffleDir is unnecessary. Added to keep the parsers working
- var requestedSplit =
- shuffleDir + "/" + ois.readObject.asInstanceOf[String]
- logInfo("requestedSplit: " + requestedSplit)
-
- // Send the length of the requestedSplit to let the receiver know that
- // transfer is about to start
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- var requestedSplitLen = -1
-
- try {
- requestedSplitLen =
- CustomParallelInMemoryShuffle.splitsCache(requestedSplit).length
- } catch {
- case e: Exception => { }
- }
-
- oos.writeObject(requestedSplitLen)
- oos.flush()
-
- logInfo("requestedSplitLen = " + requestedSplitLen)
-
- // Read and send the requested split
- if (requestedSplitLen != -1) {
- // Send
- bos.write(CustomParallelInMemoryShuffle.splitsCache(requestedSplit),
- 0, requestedSplitLen)
- bos.flush()
- } else {
- // Close the connection
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc
- // then close everything up
- // Exception can happen if the receiver stops receiving
- case e: Exception => {
- logInfo("ShuffleServerThread had a " + e)
- }
- } finally {
- logInfo("ShuffleServerThread is closing streams and sockets")
- ois.close()
- // TODO: Following can cause "java.net.SocketException: Socket closed"
- oos.close()
- bos.close()
- clientSocket.close()
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala
deleted file mode 100644
index 8e89cadfdd..0000000000
--- a/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala
+++ /dev/null
@@ -1,471 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-/**
- * An implementation of shuffle using local files served through HTTP where
- * receivers create simultaneous connections to multiple servers by setting the
- * 'spark.shuffle.maxRxConnections' config option.
- *
- * By controlling the 'spark.shuffle.blockSize' config option one can also
- * control the largest block size to retrieve by each reducers. An INDEX file
- * keeps track of block boundaries instead of creating many smaller files.
- *
- * TODO: Add support for compression when spark.compress is set to true.
- */
-@serializable
-class HttpBlockedLocalFileShuffle[K, V, C]
-extends Shuffle[K, V, C] with Logging {
- @transient var totalSplits = 0
- @transient var hasSplits = 0
-
- @transient var blocksInSplit: Array[ArrayBuffer[Long]] = null
- @transient var totalBlocksInSplit: Array[Int] = null
- @transient var hasBlocksInSplit: Array[Int] = null
-
- @transient var hasSplitsBitVector: BitSet = null
- @transient var splitsInRequestBitVector: BitSet = null
-
- @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
- @transient var combiners: HashMap[K,C] = null
-
- override def compute(input: RDD[(K, V)],
- numOutputSplits: Int,
- createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiners: (C, C) => C)
- : RDD[(K, C)] =
- {
- val sc = input.sparkContext
- val shuffleId = HttpBlockedLocalFileShuffle.newShuffleId()
- logInfo("Shuffle ID: " + shuffleId)
-
- val splitRdd = new NumberedSplitRDD(input)
- val numInputSplits = splitRdd.splits.size
-
- // Run a parallel map and collect to write the intermediate data files,
- // returning a list of inputSplitId -> serverUri pairs
- val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
- val myIndex = pair._1
- val myIterator = pair._2
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
- for ((k, v) <- myIterator) {
- var bucketId = k.hashCode % numOutputSplits
- if (bucketId < 0) { // Fix bucket ID if hash code was negative
- bucketId += numOutputSplits
- }
- val bucket = buckets(bucketId)
- bucket(k) = bucket.get(k) match {
- case Some(c) => mergeValue(c, v)
- case None => createCombiner(v)
- }
- }
-
- for (i <- 0 until numOutputSplits) {
- // Open the INDEX file
- var indexFile: File =
- HttpBlockedLocalFileShuffle.getBlockIndexOutputFile(shuffleId,
- myIndex, i)
- var indexOut = new ObjectOutputStream(new FileOutputStream(indexFile))
- var indexDirty: Boolean = true
- var alreadyWritten: Long = 0
-
- // Open the actual file
- var file: File =
- HttpBlockedLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
- val out = new ObjectOutputStream(new FileOutputStream(file))
-
- val writeStartTime = System.currentTimeMillis
- logInfo("BEGIN WRITE: " + file)
-
- buckets(i).foreach(pair => {
- out.writeObject(pair)
- out.flush()
- indexDirty = true
-
- // Update the INDEX file if more than blockSize limit has been written
- if (file.length - alreadyWritten > Shuffle.BlockSize) {
- indexOut.writeObject(file.length)
- indexDirty = false
- alreadyWritten = file.length
- }
- })
-
- // Write down the last range if it was not written
- if (indexDirty) {
- indexOut.writeObject(file.length)
- }
-
- out.close()
- indexOut.close()
-
- logInfo("END WRITE: " + file)
- val writeTime = (System.currentTimeMillis - writeStartTime)
- logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
- }
-
- (myIndex, HttpBlockedLocalFileShuffle.serverUri)
- }).collect()
-
- // TODO: Could broadcast outputLocs
-
- // Return an RDD that does each of the merges for a given partition
- val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
- return indexes.flatMap((myId: Int) => {
- totalSplits = outputLocs.size
- hasSplits = 0
-
- blocksInSplit = Array.tabulate(totalSplits)(_ => new ArrayBuffer[Long])
- totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
- hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
-
- hasSplitsBitVector = new BitSet(totalSplits)
- splitsInRequestBitVector = new BitSet(totalSplits)
-
- receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
- combiners = new HashMap[K, C]
-
- var threadPool = Shuffle.newDaemonFixedThreadPool(
- Shuffle.MaxRxConnections)
-
- while (hasSplits < totalSplits) {
- var numThreadsToCreate =
- Math.min(totalSplits, Shuffle.MaxRxConnections) -
- threadPool.getActiveCount
-
- while (hasSplits < totalSplits && numThreadsToCreate > 0) {
- // Select a random split to pull
- val splitIndex = selectRandomSplit
-
- if (splitIndex != -1) {
- val (inputId, serverUri) = outputLocs(splitIndex)
-
- threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt,
- inputId, myId, splitIndex))
-
- // splitIndex is in transit. Will be unset in the ShuffleClient
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex)
- }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before creating new threads
- Thread.sleep(Shuffle.MinKnockInterval)
- }
-
- threadPool.shutdown()
-
- // Start consumer
- // TODO: Consumption is delayed until everything has been received.
- // Otherwise it interferes with network performance
- var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
- shuffleConsumer.setDaemon(true)
- shuffleConsumer.start()
- logInfo("ShuffleConsumer started...")
-
- // Don't return until consumption is finished
- // while (receivedData.size > 0) {
- // Thread.sleep(Shuffle.MinKnockInterval)
- // }
-
- // Wait till shuffleConsumer is done
- shuffleConsumer.join
-
- combiners
- })
- }
-
- def selectRandomSplit: Int = {
- var requiredSplits = new ArrayBuffer[Int]
-
- synchronized {
- for (i <- 0 until totalSplits) {
- if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
- requiredSplits += i
- }
- }
- }
-
- if (requiredSplits.size > 0) {
- requiredSplits(HttpBlockedLocalFileShuffle.ranGen.nextInt(
- requiredSplits.size))
- } else {
- -1
- }
- }
-
- class ShuffleConsumer(mergeCombiners: (C, C) => C)
- extends Thread with Logging {
- override def run: Unit = {
- // Run until all splits are here
- while (receivedData.size > 0) {
- var splitIndex = -1
- var recvByteArray: Array[Byte] = null
-
- try {
- var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
- splitIndex = tempPair._1
- recvByteArray = tempPair._2
- } catch {
- case e: Exception => {
- logInfo("Exception during taking data from receivedData")
- }
- }
-
- val inputStream =
- new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
-
- try{
- while (true) {
- val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
- combiners(k) = combiners.get(k) match {
- case Some(oldC) => mergeCombiners(oldC, c)
- case None => c
- }
- }
- } catch {
- case e: EOFException => { }
- }
- inputStream.close()
- }
- }
- }
-
- class ShuffleClient(serverUri: String, shuffleId: Int,
- inputId: Int, myId: Int, splitIndex: Int)
- extends Thread with Logging {
- private var receptionSucceeded = false
-
- override def run: Unit = {
- try {
- // First get the INDEX file if totalBlocksInSplit(splitIndex) is unknown
- if (totalBlocksInSplit(splitIndex) == -1) {
- val url = "%s/shuffle/%d/%d/INDEX-%d".format(serverUri, shuffleId,
- inputId, myId)
- val inputStream = new ObjectInputStream(new URL(url).openStream())
-
- try {
- while (true) {
- blocksInSplit(splitIndex) +=
- inputStream.readObject().asInstanceOf[Long]
- }
- } catch {
- case e: EOFException => {}
- }
-
- totalBlocksInSplit(splitIndex) = blocksInSplit(splitIndex).size
- inputStream.close()
- }
-
- // Open connection
- val urlString =
- "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, myId)
- val url = new URL(urlString)
- val httpConnection =
- url.openConnection().asInstanceOf[HttpURLConnection]
-
- // Set the range to download
- val blockStartsAt = hasBlocksInSplit(splitIndex) match {
- case 0 => 0
- case _ => blocksInSplit(splitIndex)(hasBlocksInSplit(splitIndex) - 1) + 1
- }
- val blockEndsAt = blocksInSplit(splitIndex)(hasBlocksInSplit(splitIndex))
- httpConnection.setRequestProperty("Range",
- "bytes=" + blockStartsAt + "-" + blockEndsAt)
-
- // Connect to the server
- httpConnection.connect()
-
- val urStringWithRange =
- urlString + "[%d:%d]".format(blockStartsAt, blockEndsAt)
- val readStartTime = System.currentTimeMillis
- logInfo("BEGIN READ: " + urStringWithRange)
-
- // Receive data in an Array[Byte]
- val requestedFileLen: Int = (blockEndsAt - blockStartsAt).toInt + 1
- var recvByteArray = new Array[Byte](requestedFileLen)
- var alreadyRead = 0
- var bytesRead = 0
-
- val isSource = httpConnection.getInputStream()
- while (alreadyRead != requestedFileLen) {
- bytesRead = isSource.read(recvByteArray, alreadyRead,
- requestedFileLen - alreadyRead)
- if (bytesRead > 0) {
- alreadyRead = alreadyRead + bytesRead
- }
- }
-
- // Disconnect
- httpConnection.disconnect()
-
- // Make it available to the consumer
- try {
- receivedData.put((splitIndex, recvByteArray))
- } catch {
- case e: Exception => {
- logInfo("Exception during putting data into receivedData")
- }
- }
-
- // TODO: Updating stats before consumption is completed
- hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
-
- // Split has been received only if all the blocks have been received
- if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
- hasSplitsBitVector.synchronized {
- hasSplitsBitVector.set(splitIndex)
- hasSplits += 1
- }
- }
-
- // We have received splitIndex
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
-
- receptionSucceeded = true
-
- logInfo("END READ: " + urStringWithRange)
- val readTime = System.currentTimeMillis - readStartTime
- logInfo("Reading " + urStringWithRange + " took " + readTime + " millis.")
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleClient had a " + e)
- }
- } finally {
- // If reception failed, unset for future retry
- if (!receptionSucceeded) {
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
- }
- }
- }
- }
-}
-
-object HttpBlockedLocalFileShuffle extends Logging {
- private var initialized = false
- private var nextShuffleId = new AtomicLong(0)
-
- // Variables initialized by initializeIfNeeded()
- private var shuffleDir: File = null
- private var server: HttpServer = null
- private var serverUri: String = null
-
- // Random number generator
- var ranGen = new Random
-
- private def initializeIfNeeded() = synchronized {
- if (!initialized) {
- // TODO: localDir should be created by some mechanism common to Spark
- // so that it can be shared among shuffle, broadcast, etc
- val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
- var tries = 0
- var foundLocalDir = false
- var localDir: File = null
- var localDirUuid: UUID = null
- while (!foundLocalDir && tries < 10) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID
- localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed 10 attempts to create local dir in " + localDirRoot)
- System.exit(1)
- }
- shuffleDir = new File(localDir, "shuffle")
- shuffleDir.mkdirs()
- logInfo("Shuffle dir: " + shuffleDir)
-
- val extServerPort = System.getProperty(
- "spark.localFileShuffle.external.server.port", "-1").toInt
- if (extServerPort != -1) {
- // We're using an external HTTP server; set URI relative to its root
- var extServerPath = System.getProperty(
- "spark.localFileShuffle.external.server.path", "")
- if (extServerPath != "" && !extServerPath.endsWith("/")) {
- extServerPath += "/"
- }
- serverUri = "http://%s:%d/%s/spark-local-%s".format(
- Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
- } else {
- // Create our own server
- server = new HttpServer(localDir)
- server.start()
- serverUri = server.uri
- }
- initialized = true
- logInfo("Local URI: " + serverUri)
- }
- }
-
- def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
- initializeIfNeeded()
- val dir = new File(shuffleDir, shuffleId + "/" + inputId)
- dir.mkdirs()
- val file = new File(dir, "" + outputId)
- return file
- }
-
- def getBlockIndexOutputFile(shuffleId: Long, inputId: Int,
- outputId: Int): File = {
- initializeIfNeeded()
- val dir = new File(shuffleDir, shuffleId + "/" + inputId)
- dir.mkdirs()
- val file = new File(dir, "INDEX-" + outputId)
- return file
- }
-
- def getServerUri(): String = {
- initializeIfNeeded()
- serverUri
- }
-
- def newShuffleId(): Long = {
- nextShuffleId.getAndIncrement()
- }
-
- // Returns a standard ThreadFactory except all threads are daemons
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread(r)
- t.setDaemon(true)
- return t
- }
- }
- }
-
- // Wrapper over newFixedThreadPool
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool =
- Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
-}
diff --git a/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala
deleted file mode 100644
index cfd84fdb83..0000000000
--- a/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala
+++ /dev/null
@@ -1,465 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-/**
- * An implementation of shuffle using local files served through HTTP where
- * receivers create simultaneous connections to multiple servers by setting the
- * 'spark.shuffle.maxRxConnections' config option.
- *
- * By controlling the 'spark.shuffle.blockSize' config option one can also
- * control the largest block size to divide each map output into. Essentially,
- * instead of creating one large output file for each reducer, maps create
- * multiple smaller files to enable finer level of engagement.
- *
- * TODO: Add support for compression when spark.compress is set to true.
- */
-@serializable
-class ManualBlockedLocalFileShuffle[K, V, C]
-extends Shuffle[K, V, C] with Logging {
- @transient var totalSplits = 0
- @transient var hasSplits = 0
-
- @transient var totalBlocksInSplit: Array[Int] = null
- @transient var hasBlocksInSplit: Array[Int] = null
-
- @transient var hasSplitsBitVector: BitSet = null
- @transient var splitsInRequestBitVector: BitSet = null
-
- @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
- @transient var combiners: HashMap[K,C] = null
-
- override def compute(input: RDD[(K, V)],
- numOutputSplits: Int,
- createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiners: (C, C) => C)
- : RDD[(K, C)] =
- {
- val sc = input.sparkContext
- val shuffleId = ManualBlockedLocalFileShuffle.newShuffleId()
- logInfo("Shuffle ID: " + shuffleId)
-
- val splitRdd = new NumberedSplitRDD(input)
- val numInputSplits = splitRdd.splits.size
-
- // Run a parallel map and collect to write the intermediate data files,
- // returning a list of inputSplitId -> serverUri pairs
- val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
- val myIndex = pair._1
- val myIterator = pair._2
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
- for ((k, v) <- myIterator) {
- var bucketId = k.hashCode % numOutputSplits
- if (bucketId < 0) { // Fix bucket ID if hash code was negative
- bucketId += numOutputSplits
- }
- val bucket = buckets(bucketId)
- bucket(k) = bucket.get(k) match {
- case Some(c) => mergeValue(c, v)
- case None => createCombiner(v)
- }
- }
-
- for (i <- 0 until numOutputSplits) {
- var blockNum = 0
- var isDirty = false
- var file: File = null
- var out: ObjectOutputStream = null
-
- var writeStartTime: Long = 0
-
- buckets(i).foreach(pair => {
- // Open a new file if necessary
- if (!isDirty) {
- file = ManualBlockedLocalFileShuffle.getOutputFile(shuffleId,
- myIndex, i, blockNum)
- writeStartTime = System.currentTimeMillis
- logInfo("BEGIN WRITE: " + file)
-
- out = new ObjectOutputStream(new FileOutputStream(file))
- }
-
- out.writeObject(pair)
- out.flush()
- isDirty = true
-
- // Close the old file if has crossed the blockSize limit
- if (file.length > Shuffle.BlockSize) {
- out.close()
- logInfo("END WRITE: " + file)
- val writeTime = System.currentTimeMillis - writeStartTime
- logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
-
- blockNum = blockNum + 1
- isDirty = false
- }
- })
-
- if (isDirty) {
- out.close()
- logInfo("END WRITE: " + file)
- val writeTime = System.currentTimeMillis - writeStartTime
- logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
-
- blockNum = blockNum + 1
- }
-
- // Write the BLOCKNUM file
- file = ManualBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId,
- myIndex, i)
- out = new ObjectOutputStream(new FileOutputStream(file))
- out.writeObject(blockNum)
- out.close()
- }
-
- (myIndex, ManualBlockedLocalFileShuffle.serverUri)
- }).collect()
-
- // TODO: Could broadcast outputLocs
-
- // Return an RDD that does each of the merges for a given partition
- val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
- return indexes.flatMap((myId: Int) => {
- totalSplits = outputLocs.size
- hasSplits = 0
-
- totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
- hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
-
- hasSplitsBitVector = new BitSet(totalSplits)
- splitsInRequestBitVector = new BitSet(totalSplits)
-
- receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
- combiners = new HashMap[K, C]
-
- var threadPool = Shuffle.newDaemonFixedThreadPool(
- Shuffle.MaxRxConnections)
-
- while (hasSplits < totalSplits) {
- var numThreadsToCreate =
- Math.min(totalSplits, Shuffle.MaxRxConnections) -
- threadPool.getActiveCount
-
- while (hasSplits < totalSplits && numThreadsToCreate > 0) {
- // Select a random split to pull
- val splitIndex = selectRandomSplit
-
- if (splitIndex != -1) {
- val (inputId, serverUri) = outputLocs(splitIndex)
-
- threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt,
- inputId, myId, splitIndex, mergeCombiners))
-
- // splitIndex is in transit. Will be unset in the ShuffleClient
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex)
- }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before creating new threads
- Thread.sleep(Shuffle.MinKnockInterval)
- }
-
- threadPool.shutdown()
-
- // Start consumer
- // TODO: Consumption is delayed until everything has been received.
- // Otherwise it interferes with network performance
- var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
- shuffleConsumer.setDaemon(true)
- shuffleConsumer.start()
- logInfo("ShuffleConsumer started...")
-
- // Don't return until consumption is finished
- // while (receivedData.size > 0) {
- // Thread.sleep(Shuffle.MinKnockInterval)
- // }
-
- // Wait till shuffleConsumer is done
- shuffleConsumer.join
-
- combiners
- })
- }
-
- def selectRandomSplit: Int = {
- var requiredSplits = new ArrayBuffer[Int]
-
- synchronized {
- for (i <- 0 until totalSplits) {
- if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
- requiredSplits += i
- }
- }
- }
-
- if (requiredSplits.size > 0) {
- requiredSplits(ManualBlockedLocalFileShuffle.ranGen.nextInt(
- requiredSplits.size))
- } else {
- -1
- }
- }
-
- class ShuffleConsumer(mergeCombiners: (C, C) => C)
- extends Thread with Logging {
- override def run: Unit = {
- // Run until all splits are here
- while (receivedData.size > 0) {
- var splitIndex = -1
- var recvByteArray: Array[Byte] = null
-
- try {
- var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
- splitIndex = tempPair._1
- recvByteArray = tempPair._2
- } catch {
- case e: Exception => {
- logInfo("Exception during taking data from receivedData")
- }
- }
-
- val inputStream =
- new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
-
- try{
- while (true) {
- val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
- combiners(k) = combiners.get(k) match {
- case Some(oldC) => mergeCombiners(oldC, c)
- case None => c
- }
- }
- } catch {
- case e: EOFException => { }
- }
- inputStream.close()
- }
- }
- }
-
- class ShuffleClient(serverUri: String, shuffleId: Int,
- inputId: Int, myId: Int, splitIndex: Int,
- mergeCombiners: (C, C) => C)
- extends Thread with Logging {
- private var receptionSucceeded = false
-
- override def run: Unit = {
- try {
- // Everything will break if BLOCKNUM is not correctly received
- // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
- if (totalBlocksInSplit(splitIndex) == -1) {
- val url = "%s/shuffle/%d/%d/BLOCKNUM-%d".format(serverUri, shuffleId,
- inputId, myId)
- val inputStream = new ObjectInputStream(new URL(url).openStream())
- totalBlocksInSplit(splitIndex) =
- inputStream.readObject().asInstanceOf[Int]
- inputStream.close()
- }
-
- // Open connection
- val urlString =
- "%s/shuffle/%d/%d/%d-%d".format(serverUri, shuffleId, inputId,
- myId, hasBlocksInSplit(splitIndex))
- val url = new URL(urlString)
- val httpConnection =
- url.openConnection().asInstanceOf[HttpURLConnection]
-
- // Connect to the server
- httpConnection.connect()
-
- // Receive file length
- var requestedFileLen = httpConnection.getContentLength
-
- val readStartTime = System.currentTimeMillis
- logInfo("BEGIN READ: " + url)
-
- // Receive data in an Array[Byte]
- var recvByteArray = new Array[Byte](requestedFileLen)
- var alreadyRead = 0
- var bytesRead = 0
-
- val isSource = httpConnection.getInputStream()
- while (alreadyRead != requestedFileLen) {
- bytesRead = isSource.read(recvByteArray, alreadyRead,
- requestedFileLen - alreadyRead)
- if (bytesRead > 0) {
- alreadyRead = alreadyRead + bytesRead
- }
- }
-
- // Disconnect
- httpConnection.disconnect()
-
- // Make it available to the consumer
- try {
- receivedData.put((splitIndex, recvByteArray))
- } catch {
- case e: Exception => {
- logInfo("Exception during putting data into receivedData")
- }
- }
-
- // TODO: Updating stats before consumption is completed
- hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
-
- // Split has been received only if all the blocks have been received
- if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
- hasSplitsBitVector.synchronized {
- hasSplitsBitVector.set(splitIndex)
- hasSplits += 1
- }
- }
-
- // We have received splitIndex
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
-
- receptionSucceeded = true
-
- logInfo("END READ: " + url)
- val readTime = System.currentTimeMillis - readStartTime
- logInfo("Reading " + url + " took " + readTime + " millis.")
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleClient had a " + e)
- }
- } finally {
- // If reception failed, unset for future retry
- if (!receptionSucceeded) {
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
- }
- }
- }
- }
-}
-
-object ManualBlockedLocalFileShuffle extends Logging {
- private var initialized = false
- private var nextShuffleId = new AtomicLong(0)
-
- // Variables initialized by initializeIfNeeded()
- private var shuffleDir: File = null
- private var server: HttpServer = null
- private var serverUri: String = null
-
- // Random number generator
- var ranGen = new Random
-
- private def initializeIfNeeded() = synchronized {
- if (!initialized) {
- // TODO: localDir should be created by some mechanism common to Spark
- // so that it can be shared among shuffle, broadcast, etc
- val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
- var tries = 0
- var foundLocalDir = false
- var localDir: File = null
- var localDirUuid: UUID = null
- while (!foundLocalDir && tries < 10) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID
- localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed 10 attempts to create local dir in " + localDirRoot)
- System.exit(1)
- }
- shuffleDir = new File(localDir, "shuffle")
- shuffleDir.mkdirs()
- logInfo("Shuffle dir: " + shuffleDir)
-
- val extServerPort = System.getProperty(
- "spark.localFileShuffle.external.server.port", "-1").toInt
- if (extServerPort != -1) {
- // We're using an external HTTP server; set URI relative to its root
- var extServerPath = System.getProperty(
- "spark.localFileShuffle.external.server.path", "")
- if (extServerPath != "" && !extServerPath.endsWith("/")) {
- extServerPath += "/"
- }
- serverUri = "http://%s:%d/%s/spark-local-%s".format(
- Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
- } else {
- // Create our own server
- server = new HttpServer(localDir)
- server.start()
- serverUri = server.uri
- }
- initialized = true
- logInfo("Local URI: " + serverUri)
- }
- }
-
- def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int,
- blockId: Int): File = {
- initializeIfNeeded()
- val dir = new File(shuffleDir, shuffleId + "/" + inputId)
- dir.mkdirs()
- val file = new File(dir, "%d-%d".format(outputId, blockId))
- return file
- }
-
- def getBlockNumOutputFile(shuffleId: Long, inputId: Int,
- outputId: Int): File = {
- initializeIfNeeded()
- val dir = new File(shuffleDir, shuffleId + "/" + inputId)
- dir.mkdirs()
- val file = new File(dir, "BLOCKNUM-" + outputId)
- return file
- }
-
- def getServerUri(): String = {
- initializeIfNeeded()
- serverUri
- }
-
- def newShuffleId(): Long = {
- nextShuffleId.getAndIncrement()
- }
-
- // Returns a standard ThreadFactory except all threads are daemons
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread(r)
- t.setDaemon(true)
- return t
- }
- }
- }
-
- // Wrapper over newFixedThreadPool
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool =
- Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
-}
diff --git a/core/src/main/scala/spark/ShuffleTrackerStrategy.scala b/core/src/main/scala/spark/ShuffleTrackerStrategy.scala
deleted file mode 100644
index fc2f4aa5f7..0000000000
--- a/core/src/main/scala/spark/ShuffleTrackerStrategy.scala
+++ /dev/null
@@ -1,470 +0,0 @@
-package spark
-
-import java.util.{BitSet, Random}
-
-import scala.collection.mutable.ArrayBuffer
-import scala.util.Sorting._
-
-/**
- * A trait for implementing tracker strategies for the shuffle system.
- */
-trait ShuffleTrackerStrategy {
- // Initialize
- def initialize(outputLocs_ : Array[SplitInfo]): Unit
-
- // Select a set of splits and send back
- def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int]
-
- // Update internal stats if things could be sent back successfully
- def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit
-
- // A reducer is done. Update internal stats
- def deleteReducerFrom(reducerSplitInfo: SplitInfo,
- receptionStat: ReceptionStats): Unit
-}
-
-/**
- * Helper class to send back reception stats from the reducer
- */
-case class ReceptionStats(val bytesReceived: Int, val timeSpent: Int,
- serverSplitIndex: Int) { }
-
-/**
- * A simple ShuffleTrackerStrategy that tries to balance the total number of
- * connections created for each mapper.
- */
-class BalanceConnectionsShuffleTrackerStrategy
-extends ShuffleTrackerStrategy with Logging {
- private var numSources = -1
- private var outputLocs: Array[SplitInfo] = null
- private var curConnectionsPerLoc: Array[Int] = null
- private var totalConnectionsPerLoc: Array[Int] = null
-
- // The order of elements in the outputLocs (splitIndex) is used to pass
- // information back and forth between the tracker, mappers, and reducers
- def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
- outputLocs = outputLocs_
- numSources = outputLocs.size
-
- // Now initialize other data structures
- curConnectionsPerLoc = Array.tabulate(numSources)(_ => 0)
- totalConnectionsPerLoc = Array.tabulate(numSources)(_ => 0)
- }
-
- def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
- var minConnections = Int.MaxValue
- var minIndex = -1
-
- var splitIndices = ArrayBuffer[Int]()
-
- for (i <- 0 until numSources) {
- // TODO: Use of MaxRxConnections instead of MaxTxConnections is
- // intentional here. MaxTxConnections is per machine whereas
- // MaxRxConnections is per mapper/reducer. Will have to find a better way.
- if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections &&
- totalConnectionsPerLoc(i) < minConnections &&
- !reducerSplitInfo.hasSplitsBitVector.get(i)) {
- minConnections = totalConnectionsPerLoc(i)
- minIndex = i
- }
- }
-
- if (minIndex != -1) {
- splitIndices += minIndex
- }
-
- return splitIndices
- }
-
- def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized {
- splitIndices.foreach { splitIndex =>
- curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
- totalConnectionsPerLoc(splitIndex) =
- totalConnectionsPerLoc(splitIndex) + 1
- }
- }
-
- def deleteReducerFrom(reducerSplitInfo: SplitInfo,
- receptionStat: ReceptionStats): Unit = synchronized {
- // Decrease number of active connections
- curConnectionsPerLoc(receptionStat.serverSplitIndex) =
- curConnectionsPerLoc(receptionStat.serverSplitIndex) - 1
-
- // TODO: This assertion can legally fail when ShuffleClient times out while
- // waiting for tracker response and decides to go to a random server
- // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
-
- // Just in case
- if (curConnectionsPerLoc(receptionStat.serverSplitIndex) < 0) {
- curConnectionsPerLoc(receptionStat.serverSplitIndex) = 0
- }
- }
-}
-
-/**
- * A simple ShuffleTrackerStrategy that randomly selects mapper for each reducer
- */
-class SelectRandomShuffleTrackerStrategy
-extends ShuffleTrackerStrategy with Logging {
- private var numMappers = -1
- private var outputLocs: Array[SplitInfo] = null
-
- private var ranGen = new Random
-
- // The order of elements in the outputLocs (splitIndex) is used to pass
- // information back and forth between the tracker, mappers, and reducers
- def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
- outputLocs = outputLocs_
- numMappers = outputLocs.size
- }
-
- def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
- var splitIndex = -1
-
- do {
- splitIndex = ranGen.nextInt(numMappers)
- } while (reducerSplitInfo.hasSplitsBitVector.get(splitIndex))
-
- return ArrayBuffer(splitIndex)
- }
-
- def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized {
- }
-
- def deleteReducerFrom(reducerSplitInfo: SplitInfo,
- receptionStat: ReceptionStats): Unit = synchronized {
- // TODO: This assertion can legally fail when ShuffleClient times out while
- // waiting for tracker response and decides to go to a random server
- // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
- }
-}
-
-/**
- * Shuffle tracker strategy that tries to balance the percentage of blocks
- * remaining for each reducer
- */
-class BalanceRemainingShuffleTrackerStrategy
-extends ShuffleTrackerStrategy with Logging {
- // Number of mappers
- private var numMappers = -1
- // Number of reducers
- private var numReducers = -1
- private var outputLocs: Array[SplitInfo] = null
-
- // Data structures from reducers' perspectives
- private var totalBlocksPerInputSplit: Array[Array[Int]] = null
- private var hasBlocksPerInputSplit: Array[Array[Int]] = null
-
- // Stored in bytes per millisecond
- private var speedPerInputSplit: Array[Array[Double]] = null
-
- private var curConnectionsPerLoc: Array[Int] = null
- private var totalConnectionsPerLoc: Array[Int] = null
-
- // The order of elements in the outputLocs (splitIndex) is used to pass
- // information back and forth between the tracker, mappers, and reducers
- def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
- outputLocs = outputLocs_
-
- numMappers = outputLocs.size
-
- // All the outputLocs have totalBlocksPerOutputSplit of same size
- numReducers = outputLocs(0).totalBlocksPerOutputSplit.size
-
- // Now initialize the data structures
- totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) =>
- outputLocs(j).totalBlocksPerOutputSplit(i))
- hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0)
-
- // Initialize to -1
- speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1.0)
-
- curConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0)
- totalConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0)
- }
-
- def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
- var splitIndex = -1
-
- // Estimate time remaining to finish receiving for all reducer/mapper pairs
- // If speed is unknown or zero then make it 1 to give a large estimate
- var individualEstimates = Array.tabulate(numReducers, numMappers)((_,_) => 0.0)
- for (i <- 0 until numReducers; j <- 0 until numMappers) {
- var blocksRemaining = totalBlocksPerInputSplit(i)(j) -
- hasBlocksPerInputSplit(i)(j)
- assert(blocksRemaining >= 0)
-
- individualEstimates(i)(j) = 1.0 * blocksRemaining * Shuffle.BlockSize /
- { if (speedPerInputSplit(i)(j) <= 0.0) 1.0 else speedPerInputSplit(i)(j) }
- }
-
- // Check if all speedPerInputSplit entries have non-zero values
- var estimationComplete = true
- for (i <- 0 until numReducers; j <- 0 until numMappers) {
- if (speedPerInputSplit(i)(j) < 0.0) {
- estimationComplete = false
- }
- }
-
- // Mark mappers where this reducer is too fast
- var throttleFromMapper = Array.tabulate(numMappers)(_ => false)
-
- for (i <- 0 until numMappers) {
- var estimatesFromAMapper =
- Array.tabulate(numReducers)(j => individualEstimates(j)(i))
-
- val estimateOfThisReducer = estimatesFromAMapper(reducerSplitInfo.splitId)
-
- // Only care if this reducer yet has something to receive from this mapper
- if (estimateOfThisReducer > 0) {
- // Sort the estimated times
- quickSort(estimatesFromAMapper)
-
- // Find a Shuffle.ThrottleFraction amount of gap
- var gapIndex = -1
- for (i <- 0 until numReducers - 1) {
- if (gapIndex == -1 && estimatesFromAMapper(i) > 0 &&
- (Shuffle.ThrottleFraction * estimatesFromAMapper(i) <
- estimatesFromAMapper(i + 1))) {
- gapIndex = i
- }
-
- assert (estimatesFromAMapper(i) <= estimatesFromAMapper(i + 1))
- }
-
- // Keep track of how many have completed
- var numComplete = estimatesFromAMapper.findIndexOf(i => (i > 0))
- if (numComplete == -1) {
- numComplete = numReducers
- }
-
- // TODO: Pick a configurable parameter
- if (gapIndex != -1 && (1.0 * (gapIndex - numComplete + 1) < 0.1 * Shuffle.ThrottleFraction * (numReducers - numComplete)) &&
- estimateOfThisReducer <= estimatesFromAMapper(gapIndex)) {
- throttleFromMapper(i) = true
- logInfo("Throttling R-%d at M-%d with %f and cut-off %f at %d".format(reducerSplitInfo.splitId, i, estimateOfThisReducer, estimatesFromAMapper(gapIndex + 1), gapIndex))
-// for (i <- 0 until numReducers) {
-// print(estimatesFromAMapper(i) + " ")
-// }
-// println("")
- }
- } else {
- throttleFromMapper(i) = true
- }
- }
-
- var minConnections = Int.MaxValue
- for (i <- 0 until numMappers) {
- // TODO: Use of MaxRxConnections instead of MaxTxConnections is
- // intentional here. MaxTxConnections is per machine whereas
- // MaxRxConnections is per mapper/reducer. Will have to find a better way.
- if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections &&
- totalConnectionsPerLoc(i) < minConnections &&
- !reducerSplitInfo.hasSplitsBitVector.get(i) &&
- !throttleFromMapper(i)) {
- minConnections = totalConnectionsPerLoc(i)
- splitIndex = i
- }
- }
-
- return ArrayBuffer(splitIndex)
- }
-
- def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized {
- splitIndices.foreach { splitIndex =>
- curConnectionsPerLoc(splitIndex) += 1
- totalConnectionsPerLoc(splitIndex) += 1
- }
- }
-
- def deleteReducerFrom(reducerSplitInfo: SplitInfo,
- receptionStat: ReceptionStats): Unit = synchronized {
- // Update hasBlocksPerInputSplit for reducerSplitInfo
- hasBlocksPerInputSplit(reducerSplitInfo.splitId) =
- reducerSplitInfo.hasBlocksPerInputSplit
-
- // Store the last known speed. Add 1 to avoid divide-by-zero. Ignore 0 bytes
- // TODO: We are forgetting the old speed. Can use averaging at some point.
- if (receptionStat.bytesReceived > 0) {
- speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) =
- 1.0 * receptionStat.bytesReceived / (receptionStat.timeSpent + 1.0)
- }
-
- // logInfo("%d received %d bytes in %d millis".format(reducerSplitInfo.splitId, receptionStat.bytesReceived, receptionStat.timeSpent))
-
- // Update current connections to the mapper
- curConnectionsPerLoc(receptionStat.serverSplitIndex) -= 1
-
- // TODO: This assertion can legally fail when ShuffleClient times out while
- // waiting for tracker response and decides to go to a random server
- // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
-
- // Just in case
- if (curConnectionsPerLoc(receptionStat.serverSplitIndex) < 0) {
- curConnectionsPerLoc(receptionStat.serverSplitIndex) = 0
- }
- }
-}
-
-/**
- * Shuffle tracker strategy that allows reducers to create receiving threads
- * depending on their estimated time remaining
- */
-class LimitConnectionsShuffleTrackerStrategy
-extends ShuffleTrackerStrategy with Logging {
- // Number of mappers
- private var numMappers = -1
- // Number of reducers
- private var numReducers = -1
- private var outputLocs: Array[SplitInfo] = null
-
- private var ranGen = new Random
-
- // Data structures from reducers' perspectives
- private var totalBlocksPerInputSplit: Array[Array[Int]] = null
- private var hasBlocksPerInputSplit: Array[Array[Int]] = null
-
- // Stored in bytes per millisecond
- private var speedPerInputSplit: Array[Array[Double]] = null
-
- private var curConnectionsPerReducer: Array[Int] = null
- private var maxConnectionsPerReducer: Array[Int] = null
-
- // The order of elements in the outputLocs (splitIndex) is used to pass
- // information back and forth between the tracker, mappers, and reducers
- def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
- outputLocs = outputLocs_
-
- numMappers = outputLocs.size
-
- // All the outputLocs have totalBlocksPerOutputSplit of same size
- numReducers = outputLocs(0).totalBlocksPerOutputSplit.size
-
- // Now initialize the data structures
- totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) =>
- outputLocs(j).totalBlocksPerOutputSplit(i))
- hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0)
-
- // Initialize to -1
- speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1.0)
-
- curConnectionsPerReducer = Array.tabulate(numReducers)(_ => 0)
- maxConnectionsPerReducer = Array.tabulate(numReducers)(_ => Shuffle.MaxRxConnections)
- }
-
- def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
- var splitIndices = ArrayBuffer[Int]()
-
- // Estimate time remaining to finish receiving for all reducer/mapper pairs
- // If speed is unknown or zero then make it 1 to give a large estimate
- var individualEstimates = Array.tabulate(numReducers, numMappers)((_,_) => 0.0)
- for (i <- 0 until numReducers; j <- 0 until numMappers) {
- var blocksRemaining = totalBlocksPerInputSplit(i)(j) -
- hasBlocksPerInputSplit(i)(j)
- assert(blocksRemaining >= 0)
-
- individualEstimates(i)(j) = 1.0 * blocksRemaining * Shuffle.BlockSize /
- { if (speedPerInputSplit(i)(j) <= 0.0) 1.0 else speedPerInputSplit(i)(j) }
- }
-
- // Check if all speedPerInputSplit entries have non-zero values
- var estimationComplete = true
- for (i <- 0 until numReducers; j <- 0 until numMappers) {
- if (speedPerInputSplit(i)(j) < 0.0) {
- estimationComplete = false
- }
- }
-
- // Estimate time remaining to finish receiving for each reducer
- var completionEstimates = Array.tabulate(numReducers)(
- individualEstimates(_).foldLeft(Double.MinValue)(Math.max(_,_)))
-
- var numFinished = 0
- for (i <- 0 until numReducers) {
- if (completionEstimates(i).toInt == 0) {
- numFinished += 1
- }
- }
-
- // If a certain number of reducers have finished already, then don't bother
- if (numFinished >= ((1.0 - 0.1 * Shuffle.ThrottleFraction) * numReducers)) {
- for (i <- 0 until numReducers) {
- maxConnectionsPerReducer(i) = Shuffle.MaxRxConnections
- }
- // Otherwise, if estimation is complete give reducers proportional threads
- } else if (estimationComplete) {
- val fastestEstimate =
- completionEstimates.foldLeft(Double.MaxValue)(Math.min(_,_))
- val slowestEstimate =
- completionEstimates.foldLeft(Double.MinValue)(Math.max(_,_))
-
- // Set maxConnectionsPerReducer for all reducers proportional to their
- // estimated time remaining with slowestEstimate reducer having the max
- for (i <- 0 until numReducers) {
- maxConnectionsPerReducer(i) =
- ((completionEstimates(i) / slowestEstimate) * Shuffle.MaxRxConnections).toInt
- }
- }
-
- // Send back a splitIndex if this reducer is within its limit
- if (curConnectionsPerReducer(reducerSplitInfo.splitId) <
- maxConnectionsPerReducer(reducerSplitInfo.splitId)) {
-
- var i = maxConnectionsPerReducer(reducerSplitInfo.splitId) -
- curConnectionsPerReducer(reducerSplitInfo.splitId)
-
- var temp = reducerSplitInfo.hasSplitsBitVector.clone.asInstanceOf[BitSet]
- temp.flip(0, numMappers)
-
- i = Math.min(i, temp.cardinality)
-
- while (i > 0) {
- var splitIndex = -1
-
- do {
- splitIndex = ranGen.nextInt(numMappers)
- } while (reducerSplitInfo.hasSplitsBitVector.get(splitIndex))
-
- reducerSplitInfo.hasSplitsBitVector.set(splitIndex)
- splitIndices += splitIndex
- i -= 1
- }
- }
-
- return splitIndices
- }
-
- def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized {
- splitIndices.foreach { splitIndex =>
- curConnectionsPerReducer(reducerSplitInfo.splitId) += 1
- }
- }
-
- def deleteReducerFrom(reducerSplitInfo: SplitInfo,
- receptionStat: ReceptionStats): Unit = synchronized {
- // Update hasBlocksPerInputSplit for reducerSplitInfo
- hasBlocksPerInputSplit(reducerSplitInfo.splitId) =
- reducerSplitInfo.hasBlocksPerInputSplit
-
- // Store the last known speed. Add 1 to avoid divide-by-zero. Ignore 0 bytes
- // TODO: We are forgetting the old speed. Can use averaging at some point.
- if (receptionStat.bytesReceived > 0) {
- speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) =
- 1.0 * receptionStat.bytesReceived / (receptionStat.timeSpent + 1.0)
- }
-
- // logInfo("%d received %d bytes in %d millis".format(reducerSplitInfo.splitId, receptionStat.bytesReceived, receptionStat.timeSpent))
-
- // Update current threads by this reducer
- curConnectionsPerReducer(reducerSplitInfo.splitId) -= 1
-
- // TODO: This assertion can legally fail when ShuffleClient times out while
- // waiting for tracker response and decides to go to a random server
- // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
-
- // Just in case
- if (curConnectionsPerReducer(reducerSplitInfo.splitId) < 0) {
- curConnectionsPerReducer(reducerSplitInfo.splitId) = 0
- }
- }
-}
diff --git a/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala b/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala
deleted file mode 100644
index 0d21df9338..0000000000
--- a/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala
+++ /dev/null
@@ -1,926 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-/**
- * An implementation of shuffle using memory served through custom server
- * where receivers create simultaneous connections to multiple servers by
- * setting the 'spark.shuffle.maxRxConnections' config option.
- *
- * By controlling the 'spark.shuffle.blockSize' config option one can also
- * control the largest block size to divide each map output into. Essentially,
- * instead of creating one large output file for each reducer, maps create
- * multiple smaller files to enable finer level of engagement.
- *
- * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally,
- * maxTxConnections >= maxRxConnections * numReducersPerMachine
- *
- * 'spark.shuffle.trackerStrategy' decides which strategy to use in the tracker
- *
- * TODO: Add support for compression when spark.compress is set to true.
- */
-@serializable
-class TrackedCustomBlockedInMemoryShuffle[K, V, C]
-extends Shuffle[K, V, C] with Logging {
- @transient var totalSplits = 0
- @transient var hasSplits = 0
-
- @transient var totalBlocksInSplit: Array[Int] = null
- @transient var hasBlocksInSplit: Array[Int] = null
-
- @transient var hasSplitsBitVector: BitSet = null
- @transient var splitsInRequestBitVector: BitSet = null
-
- @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
- @transient var combiners: HashMap[K,C] = null
-
- /**
-  * Runs the shuffle in two phases and returns the combined (key, combiner) pairs.
-  *
-  * Map phase: each input split is hashed into numOutputSplits buckets. Every
-  * bucket is serialized into blocks that are closed once they grow past
-  * Shuffle.BlockSize bytes and stored in
-  * TrackedCustomBlockedInMemoryShuffle.splitsCache, next to a BLOCKNUM entry
-  * holding the number of blocks written for that bucket.
-  *
-  * Reduce phase: a ShuffleTracker is started; each output split keeps asking
-  * it which map outputs to fetch, pulls them with ShuffleClient threads and,
-  * once everything has arrived, merges the data with a ShuffleConsumer.
-  *
-  * @param input           the key-value RDD to shuffle
-  * @param numOutputSplits number of buckets, i.e. output partitions
-  * @param createCombiner  builds a combiner from the first value seen for a key
-  * @param mergeValue      folds another value into an existing combiner (map side)
-  * @param mergeCombiners  merges two combiners of the same key (reduce side)
-  * @return an RDD of (key, combiner) pairs
-  */
- override def compute(input: RDD[(K, V)],
-                      numOutputSplits: Int,
-                      createCombiner: V => C,
-                      mergeValue: (C, V) => C,
-                      mergeCombiners: (C, C) => C)
- : RDD[(K, C)] =
- {
-   val sc = input.sparkContext
-   val shuffleId = TrackedCustomBlockedInMemoryShuffle.newShuffleId()
-   logInfo("Shuffle ID: " + shuffleId)
-
-   val splitRdd = new NumberedSplitRDD(input)
-   val numInputSplits = splitRdd.splits.size
-
-   // Run a parallel map that stores every split's serialized buckets in
-   // splitsCache, and collect one SplitInfo (server address, server port,
-   // block counts) per input split
-   val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
-     val myIndex = pair._1
-     val myIterator = pair._2
-     val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
-     for ((k, v) <- myIterator) {
-       var bucketId = k.hashCode % numOutputSplits
-       if (bucketId < 0) { // Fix bucket ID if hash code was negative
-         bucketId += numOutputSplits
-       }
-       val bucket = buckets(bucketId)
-       bucket(k) = bucket.get(k) match {
-         case Some(c) => mergeValue(c, v)
-         case None => createCombiner(v)
-       }
-     }
-
-     // Keep track of number of blocks for each output split
-     var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0)
-
-     for (i <- 0 until numOutputSplits) {
-       var blockNum = 0
-       // True while a block is open and holds at least one unsaved pair
-       var isDirty = false
-
-       var splitName = ""
-       var baos: ByteArrayOutputStream = null
-       var oos: ObjectOutputStream = null
-
-       var writeStartTime: Long = 0
-
-       buckets(i).foreach(pair => {
-         // Open a new stream if necessary
-         if (!isDirty) {
-           splitName = TrackedCustomBlockedInMemoryShuffle.getSplitName(shuffleId,
-             myIndex, i, blockNum)
-
-           baos = new ByteArrayOutputStream
-           oos = new ObjectOutputStream(baos)
-
-           writeStartTime = System.currentTimeMillis
-           logInfo("BEGIN WRITE: " + splitName)
-         }
-
-         oos.writeObject(pair)
-         isDirty = true
-
-         // Close the old stream if has crossed the blockSize limit
-         if (baos.size > Shuffle.BlockSize) {
-           TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) =
-             baos.toByteArray
-
-           logInfo("END WRITE: " + splitName)
-           val writeTime = System.currentTimeMillis - writeStartTime
-           logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
-
-           blockNum = blockNum + 1
-           isDirty = false
-           oos.close()
-         }
-       })
-
-       // Save the last, partially filled block of this bucket
-       if (isDirty) {
-         TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
-
-         logInfo("END WRITE: " + splitName)
-         val writeTime = System.currentTimeMillis - writeStartTime
-         logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
-
-         blockNum = blockNum + 1
-         oos.close()
-       }
-
-       // Store BLOCKNUM info
-       splitName = TrackedCustomBlockedInMemoryShuffle.getBlockNumOutputName(
-         shuffleId, myIndex, i)
-       baos = new ByteArrayOutputStream
-       oos = new ObjectOutputStream(baos)
-       oos.writeObject(blockNum)
-       TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
-
-       // Close streams
-       oos.close()
-
-       // Store number of blocks for this outputSplit
-       numBlocksPerOutputSplit(i) = blockNum
-     }
-
-     var retVal = SplitInfo(TrackedCustomBlockedInMemoryShuffle.serverAddress,
-       TrackedCustomBlockedInMemoryShuffle.serverPort, myIndex)
-     retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit
-
-     (retVal)
-   }).collect()
-
-   // Start tracker
-   var shuffleTracker = new ShuffleTracker(outputLocs)
-   shuffleTracker.setDaemon(true)
-   shuffleTracker.start()
-   logInfo("ShuffleTracker started...")
-
-   // Return an RDD that does each of the merges for a given partition
-   val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
-   return indexes.flatMap((myId: Int) => {
-     totalSplits = outputLocs.size
-     hasSplits = 0
-
-     // -1 marks a split whose block count has not been received yet
-     totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
-     hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
-
-     hasSplitsBitVector = new BitSet(totalSplits)
-     splitsInRequestBitVector = new BitSet(totalSplits)
-
-     receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
-     combiners = new HashMap[K, C]
-
-     var threadPool = Shuffle.newDaemonFixedThreadPool(
-       Shuffle.MaxRxConnections)
-
-     while (hasSplits < totalSplits) {
-       // Local status of hasSplitsBitVector and splitsInRequestBitVector
-       val localSplitInfo = getLocalSplitInfo(myId)
-
-       // DO NOT talk to the tracker if all the required splits are already busy
-       val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality
-
-       var numThreadsToCreate =
-         Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) -
-         threadPool.getActiveCount
-
-       while (hasSplits < totalSplits && numThreadsToCreate > 0) {
-         // Receive which split to pull from the tracker
-         logInfo("Talking to tracker...")
-         val startTime = System.currentTimeMillis
-         val splitIndices = getTrackerSelectedSplit(myId)
-         logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime))
-
-         if (splitIndices.size > 0) {
-           splitIndices.foreach { splitIndex =>
-             val selectedSplitInfo = outputLocs(splitIndex)
-             val requestSplit =
-               "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
-
-             // Mark splitIndex as in transit BEFORE the client is handed to the
-             // pool. The ShuffleClient clears the bit when it finishes; if the
-             // bit were set afterwards, a client that fails fast could clear it
-             // first and leave it set for good. getLocalSplitInfo would then
-             // keep reporting the split as taken and it would never be
-             // requested again.
-             splitsInRequestBitVector.synchronized {
-               splitsInRequestBitVector.set(splitIndex)
-             }
-
-             threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
-               requestSplit, myId))
-           }
-         } else {
-           // Tracker replied back with a NO. Sleep for a while.
-           Thread.sleep(Shuffle.MinKnockInterval)
-           numThreadsToCreate = 0
-         }
-
-         numThreadsToCreate = numThreadsToCreate - splitIndices.size
-       }
-
-       // Sleep for a while before creating new threads
-       Thread.sleep(Shuffle.MinKnockInterval)
-     }
-
-     threadPool.shutdown()
-
-     // Start consumer
-     // TODO: Consumption is delayed until everything has been received.
-     // Otherwise it interferes with network performance
-     var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
-     shuffleConsumer.setDaemon(true)
-     shuffleConsumer.start()
-     logInfo("ShuffleConsumer started...")
-
-     // Don't return until consumption is finished
-     // while (receivedData.size > 0) {
-     //   Thread.sleep(Shuffle.MinKnockInterval)
-     // }
-
-     // Wait till shuffleConsumer is done
-     shuffleConsumer.join
-
-     combiners
-   })
- }
-
- // Builds the status record that reducer myId reports to the tracker: the
- // number of complete splits, copies of the received-split bits and of the
- // per-split block counts, with the splits currently being requested folded
- // into the bit vector so that they count as taken.
- private def getLocalSplitInfo(myId: Int): SplitInfo = {
-   val snapshot = SplitInfo(InetAddress.getLocalHost.getHostAddress,
-     SplitInfo.UnusedParam, myId)
-
-   // Number of completely received splits
-   snapshot.hasSplits = hasSplits
-
-   // Copy of the received-split bits, taken under the vector's own lock
-   hasSplitsBitVector.synchronized {
-     snapshot.hasSplitsBitVector = hasSplitsBitVector.clone.asInstanceOf[BitSet]
-   }
-
-   // Copy of the per-split block counts
-   hasBlocksInSplit.synchronized {
-     snapshot.hasBlocksPerInputSplit = hasBlocksInSplit.clone.asInstanceOf[Array[Int]]
-   }
-
-   // Splits that are in flight are reported as if they were already here
-   splitsInRequestBitVector.synchronized {
-     snapshot.hasSplitsBitVector.or(splitsInRequestBitVector)
-   }
-
-   snapshot
- }
-
- /**
-  * Picks, uniformly at random, a split that has neither been received nor is
-  * currently being requested.
-  *
-  * @return the index of such a split, or -1 when there is none
-  */
- def selectRandomSplit: Int = {
-   // Snapshot both bit vectors under their own monitors -- the locks that
-   // getLocalSplitInfo and ShuffleClient hold when they touch these vectors.
-   // Synchronizing on `this` would not exclude any of those writers. The two
-   // locks are taken one after the other, never nested.
-   val unavailable = hasSplitsBitVector.synchronized {
-     hasSplitsBitVector.clone.asInstanceOf[BitSet]
-   }
-   splitsInRequestBitVector.synchronized {
-     unavailable.or(splitsInRequestBitVector)
-   }
-
-   val requiredSplits = new ArrayBuffer[Int]
-   for (i <- 0 until totalSplits) {
-     if (!unavailable.get(i)) {
-       requiredSplits += i
-     }
-   }
-
-   if (requiredSplits.size > 0) {
-     requiredSplits(TrackedCustomBlockedInMemoryShuffle.ranGen.nextInt(
-       requiredSplits.size))
-   } else {
-     -1
-   }
- }
-
- /**
-  * Asks the tracker which splits reducer myId should fetch next.
-  *
-  * Sends Shuffle.ReducerEntering followed by the reducer's current SplitInfo
-  * and reads back the selected split indices. Nothing is sent when every
-  * split is already received or being requested.
-  *
-  * @param myId id of the reducer (output split) asking
-  * @return the indices chosen by the tracker; empty when there is nothing left
-  *         to ask for or when the exchange failed (the failure is logged)
-  */
- private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = {
-   // Local status of hasSplitsBitVector and splitsInRequestBitVector
-   val localSplitInfo = getLocalSplitInfo(myId)
-
-   // DO NOT talk to the tracker if all the required splits are already busy
-   if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
-     return ArrayBuffer[Int]()
-   }
-
-   val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
-     Shuffle.MasterTrackerPort)
-   val oosTracker =
-     new ObjectOutputStream(clientSocketToTracker.getOutputStream)
-   oosTracker.flush()
-   val oisTracker =
-     new ObjectInputStream(clientSocketToTracker.getInputStream)
-
-   var selectedSplitIndices = ArrayBuffer[Int]()
-
-   // Setup the timeout mechanism
-   var timeOutTask = new TimerTask {
-     override def run: Unit = {
-       logInfo("Waited enough for tracker response... Take random response...")
-
-       // sockets will be closed in finally
-       // TODO: Sometimes timer wont go off
-
-       // TODO: Selecting randomly here. Tracker won't know about it and get an
-       // assertion failure when this thread leaves
-
-       selectedSplitIndices = ArrayBuffer(selectRandomSplit)
-     }
-   }
-
-   var timeOutTimer = new Timer
-   // TODO: Which timeout to use?
-   // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval)
-
-   try {
-     // Send intention
-     oosTracker.writeObject(Shuffle.ReducerEntering)
-     oosTracker.flush()
-
-     // Send what this reducer has
-     oosTracker.writeObject(localSplitInfo)
-     oosTracker.flush()
-
-     // Receive reply from the tracker
-     selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]]
-   } catch {
-     case e: Exception => {
-       logInfo("getTrackerSelectedSplit had a " + e)
-     }
-   } finally {
-     // Turn the timer OFF on every path. Each Timer owns a thread, so a timer
-     // that is cancelled only after a successful reply would be left behind
-     // whenever the exchange fails.
-     timeOutTimer.cancel()
-
-     oisTracker.close()
-     oosTracker.close()
-     clientSocketToTracker.close()
-   }
-
-   return selectedSplitIndices
- }
-
- class ShuffleTracker(outputLocs: Array[SplitInfo])
- extends Thread with Logging { // Tracker thread: accepts reducer connections and answers each through trackerStrategy
- var threadPool = Shuffle.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- // Create trackerStrategy object: the class named by the 'spark.shuffle.trackerStrategy' system property (default below) is instantiated reflectively
- val trackerStrategyClass = System.getProperty(
- "spark.shuffle.trackerStrategy",
- "spark.BalanceConnectionsShuffleTrackerStrategy")
-
- val trackerStrategy =
- Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy]
-
- // Must initialize here by supplying the outputLocs param
- // TODO: This could be avoided by directly passing it to the constructor
- trackerStrategy.initialize(outputLocs)
-
- override def run: Unit = {
- serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
- logInfo("ShuffleTracker" + serverSocket)
-
- try {
- while (true) { // NOTE(review): no exit condition; the loop only ends if an exception escapes the body
- var clientSocket: Socket = null
- try {
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- logInfo("ShuffleTracker had a " + e)
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread { // the Thread object is handed to the pool as a task; it is not start()ed here
- override def run: Unit = {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- try {
- // Receive intention (Shuffle.ReducerEntering or Shuffle.ReducerLeaving)
- val reducerIntention = ois.readObject.asInstanceOf[Int]
-
- if (reducerIntention == Shuffle.ReducerEntering) {
- // Receive what the reducer has
- val reducerSplitInfo =
- ois.readObject.asInstanceOf[SplitInfo]
-
- // Select splits and update stats if necessary
- var selectedSplitIndices = ArrayBuffer[Int]()
- trackerStrategy.synchronized {
- selectedSplitIndices = trackerStrategy.selectSplit(
- reducerSplitInfo)
- }
-
- // Send reply back
- oos.writeObject(selectedSplitIndices)
- oos.flush()
-
- // Update internal stats, only if receiver got the reply (a failed write above throws and skips this step)
- trackerStrategy.synchronized {
- trackerStrategy.AddReducerToSplit(reducerSplitInfo,
- selectedSplitIndices)
- }
- }
- else if (reducerIntention == Shuffle.ReducerLeaving) {
- val reducerSplitInfo =
- ois.readObject.asInstanceOf[SplitInfo]
-
- // Receive reception stats: how many blocks the reducer
- // read in how much time and from where
- val receptionStat =
- ois.readObject.asInstanceOf[ReceptionStats]
-
- // Update stats
- trackerStrategy.synchronized {
- trackerStrategy.deleteReducerFrom(reducerSplitInfo,
- receptionStat)
- }
-
- // Send ACK (the server split index is echoed back)
- oos.writeObject(receptionStat.serverSplitIndex)
- oos.flush()
- }
- else {
- throw new SparkException("Undefined reducerIntention")
- }
- } catch {
- // EOFException is expected to happen because receiver can
- // break connection due to timeout and pick random instead
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleTracker had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
- } finally {
- serverSocket.close()
- }
- // Shutdown the thread pool. NOTE(review): the while (true) loop above never finishes normally, so this statement is never reached; consider moving it into the finally block
- threadPool.shutdown()
- }
- }
-
- /**
-  * Thread that drains receivedData and merges every deserialized
-  * (key, combiner) pair into the enclosing shuffle's `combiners` map.
-  *
-  * It stops as soon as the queue is empty, so it has to be started only after
-  * all blocks have been queued.
-  *
-  * @param mergeCombiners merges a newly read combiner into the one already
-  *                       stored for the same key
-  */
- class ShuffleConsumer(mergeCombiners: (C, C) => C)
- extends Thread with Logging {
-   override def run: Unit = {
-     var keepConsuming = true
-
-     // Run until every queued block has been merged
-     while (keepConsuming && receivedData.size > 0) {
-       var recvByteArray: Array[Byte] = null
-
-       try {
-         // Only the payload is needed; the split index is not used here
-         recvByteArray = receivedData.take()._2
-       } catch {
-         case ie: InterruptedException => {
-           logInfo("Exception during taking data from receivedData")
-           // Keep the interrupt visible to whoever owns this thread and stop
-           Thread.currentThread.interrupt()
-           keepConsuming = false
-         }
-         case e: Exception => {
-           logInfo("Exception during taking data from receivedData")
-           keepConsuming = false
-         }
-       }
-
-       // A failed take() leaves no block behind; there is nothing to
-       // deserialize in that case
-       if (recvByteArray != null) {
-         val inputStream =
-           new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
-
-         try {
-           // The block holds a sequence of pairs; the end is signalled by EOF
-           while (true) {
-             val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
-             combiners(k) = combiners.get(k) match {
-               case Some(oldC) => mergeCombiners(oldC, c)
-               case None => c
-             }
-           }
-         } catch {
-           case e: EOFException => { }
-         } finally {
-           inputStream.close()
-         }
-       }
-     }
-   }
- }
-
- class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo,
- requestSplit: String, myId: Int)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- private var receptionSucceeded = false
-
- // Make sure that multiple messages don't go to the tracker
- private var alreadySentLeavingNotification = false
-
- // Keep track of bytes received and time spent
- private var numBytesReceived = 0
- private var totalTimeSpent = 0
-
- /**
-  * Fetches blocks of one map output from the server in serversplitInfo.
-  *
-  * Protocol with the server: send requestSplit, read the total number of
-  * blocks, then repeatedly send the index of the next missing block and read
-  * back its length followed by that many raw bytes. Each received block is
-  * queued in receivedData. The split is marked as received once all of its
-  * blocks are here.
-  *
-  * A timer calls cleanUp() if the server has not answered the first request
-  * within Shuffle.MaxKnockInterval. Whatever happens, the in-request bit for
-  * splitIndex is cleared and cleanUp() is run before the method returns.
-  */
- override def run: Unit = {
-   // Setup the timeout mechanism
-   var timeOutTask = new TimerTask {
-     override def run: Unit = {
-       cleanUp()
-     }
-   }
-
-   var timeOutTimer = new Timer
-   timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
-
-   try {
-     // Everything will break if BLOCKNUM is not correctly received
-     // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
-     peerSocketToSource = new Socket(
-       serversplitInfo.hostAddress, serversplitInfo.listenPort)
-     oosSource =
-       new ObjectOutputStream(peerSocketToSource.getOutputStream)
-     oosSource.flush()
-     var isSource = peerSocketToSource.getInputStream
-     oisSource = new ObjectInputStream(isSource)
-
-     // Send path information
-     oosSource.writeObject(requestSplit)
-
-     // TODO: Can be optimized. No need to do it everytime.
-     // Receive BLOCKNUM
-     totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
-
-     // Turn the timer OFF, if the sender responds before timeout
-     timeOutTimer.cancel()
-
-     while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) {
-       // Set receptionSucceeded to false before trying for each block
-       receptionSucceeded = false
-
-       // Request specific block
-       oosSource.writeObject(hasBlocksInSplit(splitIndex))
-
-       // Good to go. First, receive the length of the requested file
-       var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
-       logInfo("Received requestedFileLen = " + requestedFileLen)
-
-       // Create a temp variable to be used in different places
-       val requestPath = "http://%s:%d/shuffle/%s-%d".format(
-         serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit,
-         hasBlocksInSplit(splitIndex))
-
-       // Receive the file
-       if (requestedFileLen != -1) {
-         val readStartTime = System.currentTimeMillis
-         logInfo("BEGIN READ: " + requestPath)
-
-         // Receive data in an Array[Byte]
-         var recvByteArray = new Array[Byte](requestedFileLen)
-         var alreadyRead = 0
-         var bytesRead = 0
-
-         while (alreadyRead != requestedFileLen) {
-           bytesRead = isSource.read(recvByteArray, alreadyRead,
-             requestedFileLen - alreadyRead)
-           // read() returns -1 once the server has closed the stream. Without
-           // this check the loop would spin forever on a short transfer.
-           if (bytesRead < 0) {
-             throw new EOFException("Stream ended after " + alreadyRead +
-               " of " + requestedFileLen + " bytes of " + requestPath)
-           }
-           alreadyRead = alreadyRead + bytesRead
-         }
-
-         // Make it available to the consumer
-         try {
-           receivedData.put((splitIndex, recvByteArray))
-         } catch {
-           case e: Exception => {
-             logInfo("Exception during putting data into receivedData")
-           }
-         }
-
-         // TODO: Updating stats before consumption is completed
-         hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
-
-         // Split has been received only if all the blocks have been received
-         if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
-           hasSplitsBitVector.synchronized {
-             hasSplitsBitVector.set(splitIndex)
-             hasSplits += 1
-           }
-         }
-
-         receptionSucceeded = true
-
-         logInfo("END READ: " + requestPath)
-         val readTime = System.currentTimeMillis - readStartTime
-         logInfo("Reading " + requestPath + " took " + readTime + " millis.")
-
-         // Update stats
-         numBytesReceived = numBytesReceived + requestedFileLen
-         totalTimeSpent = totalTimeSpent + readTime.toInt
-       } else {
-         throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit)
-       }
-     }
-   } catch {
-     // EOFException is expected to happen because sender can break
-     // connection due to timeout
-     case eofe: java.io.EOFException => { }
-     case e: Exception => {
-       logInfo("ShuffleClient had a " + e)
-     }
-   } finally {
-     // Make sure the timer is off even when the attempt failed before the
-     // first reply: cleanUp() runs right below, so a later timeout would only
-     // repeat it, and an uncancelled Timer keeps its thread around.
-     timeOutTimer.cancel()
-
-     splitsInRequestBitVector.synchronized {
-       splitsInRequestBitVector.set(splitIndex, false)
-     }
-     cleanUp()
-   }
- }
-
- // Tells the tracker that this client is done with its server. Three messages
- // go out on a fresh connection -- Shuffle.ReducerLeaving, the reducer's
- // current SplitInfo and the ReceptionStats of this client -- and then the
- // tracker's ACK is awaited. Once an ACK has been received, further calls do
- // nothing.
- private def sendLeavingNotification(): Unit = synchronized {
-   if (!alreadySentLeavingNotification) {
-     val trackerSocket = new Socket(Shuffle.MasterHostAddress,
-       Shuffle.MasterTrackerPort)
-     val toTracker = new ObjectOutputStream(trackerSocket.getOutputStream)
-     toTracker.flush()
-     val fromTracker = new ObjectInputStream(trackerSocket.getInputStream)
-
-     try {
-       // 1. Intention
-       toTracker.writeObject(Shuffle.ReducerLeaving)
-       toTracker.flush()
-
-       // 2. What this reducer has by now
-       toTracker.writeObject(getLocalSplitInfo(myId))
-       toTracker.flush()
-
-       // 3. Bytes received, time spent and the server they came from
-       toTracker.writeObject(ReceptionStats(
-         numBytesReceived, totalTimeSpent, splitIndex))
-       toTracker.flush()
-
-       // Wait for the ACK; its value is not used
-       fromTracker.readObject.asInstanceOf[Int]
-
-       // Only now is the notification considered sent
-       alreadySentLeavingNotification = true
-     } catch {
-       case e: Exception => {
-         logInfo("sendLeavingNotification had a " + e)
-       }
-     } finally {
-       fromTracker.close()
-       toTracker.close()
-       trackerSocket.close()
-     }
-   }
- }
-
- /**
-  * Reports to the tracker that this client is leaving and then releases the
-  * connection to the map-side server.
-  *
-  * The streams and the socket are closed even when the tracker notification
-  * throws (it opens its own socket, which can fail) or when one of the
-  * earlier close() calls fails, so the connection is never left open. Any such
-  * exception still propagates to the caller.
-  */
- private def cleanUp(): Unit = {
-   try {
-     // Update tracker stats first
-     sendLeavingNotification()
-   } finally {
-     // Clean up the connections to the mapper
-     try {
-       if (oisSource != null) {
-         oisSource.close()
-       }
-     } finally {
-       try {
-         if (oosSource != null) {
-           oosSource.close()
-         }
-       } finally {
-         if (peerSocketToSource != null) {
-           peerSocketToSource.close()
-         }
-       }
-     }
-   }
-
-   logInfo("Leaving client")
- }
- }
-}
-
-object TrackedCustomBlockedInMemoryShuffle extends Logging {
- // Cache for keeping the splits around
- val splitsCache = new HashMap[String, Array[Byte]]
-
- private var initialized = false
- private var nextShuffleId = new AtomicLong(0)
-
- // Variables initialized by initializeIfNeeded()
- private var shuffleDir: File = null
-
- private var shuffleServer: ShuffleServer = null
- private var serverAddress = InetAddress.getLocalHost.getHostAddress
- private var serverPort: Int = -1
-
- // Random number generator
- var ranGen = new Random
-
- // One-time setup, guarded by `initialized`: picks a fresh
- // "spark-local-<uuid>" directory under spark.local.dir (default /tmp),
- // creates its "shuffle" subdirectory and starts the ShuffleServer as a
- // daemon thread. Exits the JVM if no fresh directory was found in 10 tries.
- private def initializeIfNeeded() = synchronized {
-   if (!initialized) {
-     // TODO: localDir should be created by some mechanism common to Spark
-     // so that it can be shared among shuffle, broadcast, etc
-     val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
-     val maxAttempts = 10
-     var attempt = 0
-     var created = false
-     var localDir: File = null
-
-     while (!created && attempt < maxAttempts) {
-       attempt += 1
-       try {
-         localDir = new File(localDirRoot, "spark-local-" + UUID.randomUUID)
-         // Only a name that is not in use yet is accepted
-         if (!localDir.exists) {
-           localDir.mkdirs()
-           created = true
-         }
-       } catch {
-         case e: Exception =>
-           logWarning("Attempt " + attempt + " to create local dir failed", e)
-       }
-     }
-
-     if (!created) {
-       logError("Failed 10 attempts to create local dir in " + localDirRoot)
-       System.exit(1)
-     }
-
-     shuffleDir = new File(localDir, "shuffle")
-     shuffleDir.mkdirs()
-     logInfo("Shuffle dir: " + shuffleDir)
-
-     // Bring up the server that hands the cached splits to reducers
-     shuffleServer = new ShuffleServer
-     shuffleServer.setDaemon(true)
-     shuffleServer.start()
-     logInfo("ShuffleServer started...")
-
-     initialized = true
-   }
- }
-
- // Key under which one block of a map output is kept in splitsCache:
- // <shuffleDir>/<shuffleId>/<inputId>/<outputId>-<blockId>
- def getSplitName(shuffleId: Long, inputId: Int, outputId: Int,
-     blockId: Int): String = {
-   initializeIfNeeded()
-   // The shuffleDir prefix is not needed for in-memory data; it is only there
-   // to keep the parsers working
-   val pattern = "%s/%d/%d/%d-%d"
-   pattern.format(shuffleDir, shuffleId, inputId, outputId, blockId)
- }
-
- // Key of the entry that records how many blocks were written for one
- // (input split, output split) pair:
- // <shuffleDir>/<shuffleId>/<inputId>/<outputId>-BLOCKNUM
- def getBlockNumOutputName(shuffleId: Long, inputId: Int,
-     outputId: Int): String = {
-   initializeIfNeeded()
-   val pattern = "%s/%d/%d/%d-BLOCKNUM"
-   pattern.format(shuffleDir, shuffleId, inputId, outputId)
- }
-
- // Hands out the next shuffle id from the shared atomic counter
- def newShuffleId(): Long = nextShuffleId.getAndIncrement()
-
- class ShuffleServer
- extends Thread with Logging {
- var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
-
- var serverSocket: ServerSocket = null
-
- // Accept loop of the shuffle server: binds to a free port, publishes it in
- // serverPort and hands every accepted connection to a ShuffleServerThread
- // run by the pool.
- override def run: Unit = {
-   serverSocket = new ServerSocket(0)
-   serverPort = serverSocket.getLocalPort
-
-   logInfo("ShuffleServer started with " + serverSocket)
-   logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
-
-   try {
-     while (true) {
-       // A failed accept() is ignored and simply retried
-       val accepted: Socket =
-         try {
-           serverSocket.accept()
-         } catch {
-           case e: Exception => null
-         }
-
-       if (accepted != null) {
-         logInfo("Serve: Accepted new client connection:" + accepted)
-         try {
-           threadPool.execute(new ShuffleServerThread(accepted))
-         } catch {
-           // On failure the socket is closed here; otherwise the thread
-           // serving it closes it
-           case ioe: IOException => accepted.close()
-         }
-       }
-     }
-   } finally {
-     if (serverSocket != null) {
-       logInfo("ShuffleServer now stopping...")
-       serverSocket.close()
-     }
-   }
-   // Shutdown the thread pool
-   threadPool.shutdown()
- }
-
- class ShuffleServerThread(val clientSocket: Socket)
- extends Thread with Logging { // Serves one reducer connection: sends the block count, then the requested blocks out of splitsCache
- private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
- os.flush()
- private val bos = new BufferedOutputStream(os) // carries the raw block bytes
- bos.flush()
- private val oos = new ObjectOutputStream(os) // carries the control values (block count, block length); shares os with bos
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ShuffleServerThread is running")
-
- override def run: Unit = {
- try {
- // Receive basic path information
- var requestedSplitBase = ois.readObject.asInstanceOf[String]
-
- logInfo("requestedSplitBase: " + requestedSplitBase)
-
- // Read BLOCKNUM and send back the total number of blocks
- val blockNumName = "%s/%s-BLOCKNUM".format(shuffleDir,
- requestedSplitBase)
-
- val blockNumIn = new ObjectInputStream(new ByteArrayInputStream(
- TrackedCustomBlockedInMemoryShuffle.splitsCache(blockNumName)))
- val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
- blockNumIn.close()
-
- oos.writeObject(BLOCKNUM)
- oos.flush()
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = Shuffle.MaxChatBlocks // at most this many blocks are sent on one connection
-
- while (keepSending && numBlocksToSend > 0) {
- // Receive specific block request
- val blockId = ois.readObject.asInstanceOf[Int]
-
- // Ready to send
- var requestedSplit = shuffleDir + "/" + requestedSplitBase + "-" + blockId
-
- // Send the length of the requestedSplit to let the receiver know that
- // transfer is about to start
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- var requestedSplitLen = -1 // stays -1 when the block is not in splitsCache
-
- try {
- requestedSplitLen =
- TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length
- } catch {
- case e: Exception => { }
- }
-
- oos.writeObject(requestedSplitLen)
- oos.flush()
-
- logInfo("requestedSplitLen = " + requestedSplitLen)
-
- // Read and send the requested file
- if (requestedSplitLen != -1) {
- // Send
- bos.write(TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit),
- 0, requestedSplitLen)
- bos.flush()
-
- // Update loop variables
- numBlocksToSend = numBlocksToSend - 1
-
- curTime = System.currentTimeMillis
- // Revoke sending only if there is anyone waiting in the queue
- // TODO: Turning OFF the optimization so that reducers go back to
- // tracker get advice
- if (curTime - startTime >= Shuffle.MaxChatTime /* &&
- threadPool.getQueue.size > 0 */) {
- keepSending = false
- }
- } else {
- // Close the connection. NOTE(review): nothing is done in this branch, so the loop goes on to wait for the next request
- }
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc
- // then close everything up
- // Exception can happen if the receiver stops receiving
- // EOFException is expected to happen because receiver can break
- // connection as soon as it has all the blocks
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleServerThread had a " + e)
- }
- } finally {
- logInfo("ShuffleServerThread is closing streams and sockets")
- ois.close()
- // TODO: Following can cause "java.net.SocketException: Socket closed"
- oos.close()
- bos.close()
- clientSocket.close()
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala
deleted file mode 100644
index a27fa628c6..0000000000
--- a/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala
+++ /dev/null
@@ -1,930 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-/**
- * An implementation of shuffle using local files served through custom server
- * where receivers create simultaneous connections to multiple servers by
- * setting the 'spark.shuffle.maxRxConnections' config option.
- *
- * By controlling the 'spark.shuffle.blockSize' config option one can also
- * control the largest block size to divide each map output into. Essentially,
- * instead of creating one large output file for each reducer, maps create
- * multiple smaller files to enable finer level of engagement.
- *
- * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally,
- * maxTxConnections >= maxRxConnections * numReducersPerMachine
- *
- * 'spark.shuffle.TrackerStrategy' decides which strategy to use in the tracker
- *
- * TODO: Add support for compression when spark.compress is set to true.
- */
-@serializable
-class TrackedCustomBlockedLocalFileShuffle[K, V, C]
-extends Shuffle[K, V, C] with Logging {
- @transient var totalSplits = 0
- @transient var hasSplits = 0
-
- @transient var totalBlocksInSplit: Array[Int] = null
- @transient var hasBlocksInSplit: Array[Int] = null
-
- @transient var hasSplitsBitVector: BitSet = null
- @transient var splitsInRequestBitVector: BitSet = null
-
- @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
- @transient var combiners: HashMap[K,C] = null
-
- override def compute(input: RDD[(K, V)],
- numOutputSplits: Int,
- createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiners: (C, C) => C)
- : RDD[(K, C)] =
- {
- val sc = input.sparkContext
- val shuffleId = TrackedCustomBlockedLocalFileShuffle.newShuffleId()
- logInfo("Shuffle ID: " + shuffleId)
-
- val splitRdd = new NumberedSplitRDD(input)
- val numInputSplits = splitRdd.splits.size
-
- // Run a parallel map and collect to write the intermediate data files,
- // returning a list of inputSplitId -> serverUri pairs
- val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
- val myIndex = pair._1
- val myIterator = pair._2
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
- for ((k, v) <- myIterator) {
- var bucketId = k.hashCode % numOutputSplits
- if (bucketId < 0) { // Fix bucket ID if hash code was negative
- bucketId += numOutputSplits
- }
- val bucket = buckets(bucketId)
- bucket(k) = bucket.get(k) match {
- case Some(c) => mergeValue(c, v)
- case None => createCombiner(v)
- }
- }
-
- // Keep track of number of blocks for each output split
- var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0)
-
- for (i <- 0 until numOutputSplits) {
- var blockNum = 0
- var isDirty = false
- var file: File = null
- var out: ObjectOutputStream = null
-
- var writeStartTime: Long = 0
-
- buckets(i).foreach(pair => {
- // Open a new file if necessary
- if (!isDirty) {
- file = TrackedCustomBlockedLocalFileShuffle.getOutputFile(shuffleId,
- myIndex, i, blockNum)
- writeStartTime = System.currentTimeMillis
- logInfo("BEGIN WRITE: " + file)
-
- out = new ObjectOutputStream(new FileOutputStream(file))
- }
-
- out.writeObject(pair)
- out.flush()
- isDirty = true
-
- // Close the old file if has crossed the blockSize limit
- if (file.length > Shuffle.BlockSize) {
- out.close()
- logInfo("END WRITE: " + file)
- val writeTime = System.currentTimeMillis - writeStartTime
- logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
-
- blockNum = blockNum + 1
- isDirty = false
- }
- })
-
- if (isDirty) {
- out.close()
- logInfo("END WRITE: " + file)
- val writeTime = System.currentTimeMillis - writeStartTime
- logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
-
- blockNum = blockNum + 1
- }
-
- // Write the BLOCKNUM file
- file = TrackedCustomBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId,
- myIndex, i)
- out = new ObjectOutputStream(new FileOutputStream(file))
- out.writeObject(blockNum)
- out.close()
-
- // Store number of blocks for this outputSplit
- numBlocksPerOutputSplit(i) = blockNum
- }
-
- var retVal = SplitInfo(TrackedCustomBlockedLocalFileShuffle.serverAddress,
- TrackedCustomBlockedLocalFileShuffle.serverPort, myIndex)
- retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit
-
- (retVal)
- }).collect()
-
- // Start tracker
- var shuffleTracker = new ShuffleTracker(outputLocs)
- shuffleTracker.setDaemon(true)
- shuffleTracker.start()
- logInfo("ShuffleTracker started...")
-
- // Return an RDD that does each of the merges for a given partition
- val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
- return indexes.flatMap((myId: Int) => {
- totalSplits = outputLocs.size
- hasSplits = 0
-
- totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
- hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
-
- hasSplitsBitVector = new BitSet(totalSplits)
- splitsInRequestBitVector = new BitSet(totalSplits)
-
- receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
- combiners = new HashMap[K, C]
-
- var threadPool = Shuffle.newDaemonFixedThreadPool(
- Shuffle.MaxRxConnections)
-
- while (hasSplits < totalSplits) {
- // Local status of hasSplitsBitVector and splitsInRequestBitVector
- val localSplitInfo = getLocalSplitInfo(myId)
-
- // DO NOT talk to the tracker if all the required splits are already busy
- val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality
-
- var numThreadsToCreate =
- Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) -
- threadPool.getActiveCount
-
- while (hasSplits < totalSplits && numThreadsToCreate > 0) {
- // Receive which split to pull from the tracker
- logInfo("Talking to tracker...")
- val startTime = System.currentTimeMillis
- val splitIndices = getTrackerSelectedSplit(myId)
- logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime))
-
- if (splitIndices.size > 0) {
- splitIndices.foreach { splitIndex =>
- val selectedSplitInfo = outputLocs(splitIndex)
- val requestSplit =
- "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
-
- threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
- requestSplit, myId))
-
- // splitIndex is in transit. Will be unset in the ShuffleClient
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex)
- }
- }
- } else {
- // Tracker replied back with a NO. Sleep for a while.
- Thread.sleep(Shuffle.MinKnockInterval)
- numThreadsToCreate = 0
- }
-
- numThreadsToCreate = numThreadsToCreate - splitIndices.size
- }
-
- // Sleep for a while before creating new threads
- Thread.sleep(Shuffle.MinKnockInterval)
- }
-
- threadPool.shutdown()
-
- // Start consumer
- // TODO: Consumption is delayed until everything has been received.
- // Otherwise it interferes with network performance
- var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
- shuffleConsumer.setDaemon(true)
- shuffleConsumer.start()
- logInfo("ShuffleConsumer started...")
-
- // Don't return until consumption is finished
- // while (receivedData.size > 0) {
- // Thread.sleep(Shuffle.MinKnockInterval)
- // }
-
- // Wait till shuffleConsumer is done
- shuffleConsumer.join
-
- combiners
- })
- }
-
- private def getLocalSplitInfo(myId: Int): SplitInfo = {
- var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress,
- SplitInfo.UnusedParam, myId)
-
- // Store hasSplits
- localSplitInfo.hasSplits = hasSplits
-
- // Store hasSplitsBitVector
- hasSplitsBitVector.synchronized {
- localSplitInfo.hasSplitsBitVector =
- hasSplitsBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Store hasBlocksInSplit to hasBlocksPerInputSplit
- hasBlocksInSplit.synchronized {
- localSplitInfo.hasBlocksPerInputSplit =
- hasBlocksInSplit.clone.asInstanceOf[Array[Int]]
- }
-
- // Include the splitsInRequest as well
- splitsInRequestBitVector.synchronized {
- localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector)
- }
-
- return localSplitInfo
- }
-
- def selectRandomSplit: Int = {
- var requiredSplits = new ArrayBuffer[Int]
-
- synchronized {
- for (i <- 0 until totalSplits) {
- if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
- requiredSplits += i
- }
- }
- }
-
- if (requiredSplits.size > 0) {
- requiredSplits(TrackedCustomBlockedLocalFileShuffle.ranGen.nextInt(
- requiredSplits.size))
- } else {
- -1
- }
- }
-
- // Talks to the tracker and receives instruction
- private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = {
- // Local status of hasSplitsBitVector and splitsInRequestBitVector
- val localSplitInfo = getLocalSplitInfo(myId)
-
- // DO NOT talk to the tracker if all the required splits are already busy
- if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
- return ArrayBuffer[Int]()
- }
-
- val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
- Shuffle.MasterTrackerPort)
- val oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- val oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- var selectedSplitIndices = ArrayBuffer[Int]()
-
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run: Unit = {
- logInfo("Waited enough for tracker response... Take random response...")
-
- // sockets will be closed in finally
- // TODO: Sometimes timer wont go off
-
- // TODO: Selecting randomly here. Tracker won't know about it and get an
- // asssertion failure when this thread leaves
-
- selectedSplitIndices = ArrayBuffer(selectRandomSplit)
- }
- }
-
- var timeOutTimer = new Timer
- // TODO: Which timeout to use?
- // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval)
-
- try {
- // Send intention
- oosTracker.writeObject(Shuffle.ReducerEntering)
- oosTracker.flush()
-
- // Send what this reducer has
- oosTracker.writeObject(localSplitInfo)
- oosTracker.flush()
-
- // Receive reply from the tracker
- selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]]
-
- // Turn the timer OFF
- timeOutTimer.cancel()
- } catch {
- case e: Exception => {
- logInfo("getTrackerSelectedSplit had a " + e)
- }
- } finally {
- oisTracker.close()
- oosTracker.close()
- clientSocketToTracker.close()
- }
-
- return selectedSplitIndices
- }
-
- class ShuffleTracker(outputLocs: Array[SplitInfo])
- extends Thread with Logging {
- var threadPool = Shuffle.newDaemonCachedThreadPool
- var serverSocket: ServerSocket = null
-
- // Create trackerStrategy object
- val trackerStrategyClass = System.getProperty(
- "spark.shuffle.trackerStrategy",
- "spark.BalanceConnectionsShuffleTrackerStrategy")
-
- val trackerStrategy =
- Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy]
-
- // Must initialize here by supplying the outputLocs param
- // TODO: This could be avoided by directly passing it to the constructor
- trackerStrategy.initialize(outputLocs)
-
- override def run: Unit = {
- serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
- logInfo("ShuffleTracker" + serverSocket)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- logInfo("ShuffleTracker had a " + e)
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run: Unit = {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- try {
- // Receive intention
- val reducerIntention = ois.readObject.asInstanceOf[Int]
-
- if (reducerIntention == Shuffle.ReducerEntering) {
- // Receive what the reducer has
- val reducerSplitInfo =
- ois.readObject.asInstanceOf[SplitInfo]
-
- // Select splits and update stats if necessary
- var selectedSplitIndices = ArrayBuffer[Int]()
- trackerStrategy.synchronized {
- selectedSplitIndices = trackerStrategy.selectSplit(
- reducerSplitInfo)
- }
-
- // Send reply back
- oos.writeObject(selectedSplitIndices)
- oos.flush()
-
- // Update internal stats, only if receiver got the reply
- trackerStrategy.synchronized {
- trackerStrategy.AddReducerToSplit(reducerSplitInfo,
- selectedSplitIndices)
- }
- }
- else if (reducerIntention == Shuffle.ReducerLeaving) {
- val reducerSplitInfo =
- ois.readObject.asInstanceOf[SplitInfo]
-
- // Receive reception stats: how many blocks the reducer
- // read in how much time and from where
- val receptionStat =
- ois.readObject.asInstanceOf[ReceptionStats]
-
- // Update stats
- trackerStrategy.synchronized {
- trackerStrategy.deleteReducerFrom(reducerSplitInfo,
- receptionStat)
- }
-
- // Send ACK
- oos.writeObject(receptionStat.serverSplitIndex)
- oos.flush()
- }
- else {
- throw new SparkException("Undefined reducerIntention")
- }
- } catch {
- // EOFException is expected to happen because receiver can
- // break connection due to timeout and pick random instead
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleTracker had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
- } finally {
- serverSocket.close()
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
-
- class ShuffleConsumer(mergeCombiners: (C, C) => C)
- extends Thread with Logging {
- override def run: Unit = {
- // Run until all splits are here
- while (receivedData.size > 0) {
- var splitIndex = -1
- var recvByteArray: Array[Byte] = null
-
- try {
- var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
- splitIndex = tempPair._1
- recvByteArray = tempPair._2
- } catch {
- case e: Exception => {
- logInfo("Exception during taking data from receivedData")
- }
- }
-
- val inputStream =
- new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
-
- try{
- while (true) {
- val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
- combiners(k) = combiners.get(k) match {
- case Some(oldC) => mergeCombiners(oldC, c)
- case None => c
- }
- }
- } catch {
- case e: EOFException => { }
- }
- inputStream.close()
- }
- }
- }
-
- class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo,
- requestSplit: String, myId: Int)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- private var receptionSucceeded = false
-
- // Make sure that multiple messages don't go to the tracker
- private var alreadySentLeavingNotification = false
-
- // Keep track of bytes received and time spent
- private var numBytesReceived = 0
- private var totalTimeSpent = 0
-
- override def run: Unit = {
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run: Unit = {
- cleanUp()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
-
- try {
- // Everything will break if BLOCKNUM is not correctly received
- // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
- peerSocketToSource = new Socket(
- serversplitInfo.hostAddress, serversplitInfo.listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- var isSource = peerSocketToSource.getInputStream
- oisSource = new ObjectInputStream(isSource)
-
- // Send path information
- oosSource.writeObject(requestSplit)
-
- // TODO: Can be optimized. No need to do it everytime.
- // Receive BLOCKNUM
- totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) {
- // Set receptionSucceeded to false before trying for each block
- receptionSucceeded = false
-
- // Request specific block
- oosSource.writeObject(hasBlocksInSplit(splitIndex))
-
- // Good to go. First, receive the length of the requested file
- var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
- logInfo("Received requestedFileLen = " + requestedFileLen)
-
- // Create a temp variable to be used in different places
- val requestPath = "http://%s:%d/shuffle/%s-%d".format(
- serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit,
- hasBlocksInSplit(splitIndex))
-
- // Receive the file
- if (requestedFileLen != -1) {
- val readStartTime = System.currentTimeMillis
- logInfo("BEGIN READ: " + requestPath)
-
- // Receive data in an Array[Byte]
- var recvByteArray = new Array[Byte](requestedFileLen)
- var alreadyRead = 0
- var bytesRead = 0
-
- while (alreadyRead != requestedFileLen) {
- bytesRead = isSource.read(recvByteArray, alreadyRead,
- requestedFileLen - alreadyRead)
- if (bytesRead > 0) {
- alreadyRead = alreadyRead + bytesRead
- }
- }
-
- // Make it available to the consumer
- try {
- receivedData.put((splitIndex, recvByteArray))
- } catch {
- case e: Exception => {
- logInfo("Exception during putting data into receivedData")
- }
- }
-
- // TODO: Updating stats before consumption is completed
- hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
-
- // Split has been received only if all the blocks have been received
- if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
- hasSplitsBitVector.synchronized {
- hasSplitsBitVector.set(splitIndex)
- hasSplits += 1
- }
- }
-
- receptionSucceeded = true
-
- logInfo("END READ: " + requestPath)
- val readTime = System.currentTimeMillis - readStartTime
- logInfo("Reading " + requestPath + " took " + readTime + " millis.")
-
- // Update stats
- numBytesReceived = numBytesReceived + requestedFileLen
- totalTimeSpent = totalTimeSpent + readTime.toInt
- } else {
- throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit)
- }
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleClient had a " + e)
- }
- } finally {
- splitsInRequestBitVector.synchronized {
- splitsInRequestBitVector.set(splitIndex, false)
- }
- cleanUp()
- }
- }
-
- // Connect to the tracker and update its stats
- private def sendLeavingNotification(): Unit = synchronized {
- if (!alreadySentLeavingNotification) {
- val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
- Shuffle.MasterTrackerPort)
- val oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- val oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- try {
- // Send intention
- oosTracker.writeObject(Shuffle.ReducerLeaving)
- oosTracker.flush()
-
- // Send reducerSplitInfo
- oosTracker.writeObject(getLocalSplitInfo(myId))
- oosTracker.flush()
-
- // Send reception stats
- oosTracker.writeObject(ReceptionStats(
- numBytesReceived, totalTimeSpent, splitIndex))
- oosTracker.flush()
-
- // Receive ACK. No need to do anything with that
- oisTracker.readObject.asInstanceOf[Int]
-
- // Now update sentLeavingNotifacation
- alreadySentLeavingNotification = true
- } catch {
- case e: Exception => {
- logInfo("sendLeavingNotification had a " + e)
- }
- } finally {
- oisTracker.close()
- oosTracker.close()
- clientSocketToTracker.close()
- }
- }
- }
-
- private def cleanUp(): Unit = {
- // Update tracker stats first
- sendLeavingNotification()
-
- // Clean up the connections to the mapper
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
-
- logInfo("Leaving client")
- }
- }
-}
-
-object TrackedCustomBlockedLocalFileShuffle extends Logging {
- private var initialized = false
- private var nextShuffleId = new AtomicLong(0)
-
- // Variables initialized by initializeIfNeeded()
- private var shuffleDir: File = null
-
- private var shuffleServer: ShuffleServer = null
- private var serverAddress = InetAddress.getLocalHost.getHostAddress
- private var serverPort: Int = -1
-
- // Random number generator
- var ranGen = new Random
-
- private def initializeIfNeeded() = synchronized {
- if (!initialized) {
- // TODO: localDir should be created by some mechanism common to Spark
- // so that it can be shared among shuffle, broadcast, etc
- val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
- var tries = 0
- var foundLocalDir = false
- var localDir: File = null
- var localDirUuid: UUID = null
- while (!foundLocalDir && tries < 10) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID
- localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed 10 attempts to create local dir in " + localDirRoot)
- System.exit(1)
- }
- shuffleDir = new File(localDir, "shuffle")
- shuffleDir.mkdirs()
- logInfo("Shuffle dir: " + shuffleDir)
-
- // Create and start the shuffleServer
- shuffleServer = new ShuffleServer
- shuffleServer.setDaemon(true)
- shuffleServer.start()
- logInfo("ShuffleServer started...")
-
- initialized = true
- }
- }
-
- def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int,
- blockId: Int): File = {
- initializeIfNeeded()
- val dir = new File(shuffleDir, shuffleId + "/" + inputId)
- dir.mkdirs()
- val file = new File(dir, "%d-%d".format(outputId, blockId))
- return file
- }
-
- def getBlockNumOutputFile(shuffleId: Long, inputId: Int,
- outputId: Int): File = {
- initializeIfNeeded()
- val dir = new File(shuffleDir, shuffleId + "/" + inputId)
- dir.mkdirs()
- val file = new File(dir, outputId + "-BLOCKNUM")
- return file
- }
-
- def newShuffleId(): Long = {
- nextShuffleId.getAndIncrement()
- }
-
- class ShuffleServer
- extends Thread with Logging {
- var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
-
- var serverSocket: ServerSocket = null
-
- override def run: Unit = {
- serverSocket = new ServerSocket(0)
- serverPort = serverSocket.getLocalPort
-
- logInfo("ShuffleServer started with " + serverSocket)
- logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
-
- try {
- while (true) {
- var clientSocket: Socket = null
- try {
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logInfo("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ShuffleServerThread(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ShuffleServer now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ShuffleServerThread(val clientSocket: Socket)
- extends Thread with Logging {
- private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
- os.flush()
- private val bos = new BufferedOutputStream(os)
- bos.flush()
- private val oos = new ObjectOutputStream(os)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ShuffleServerThread is running")
-
- override def run: Unit = {
- try {
- // Receive basic path information
- var requestPathBase = ois.readObject.asInstanceOf[String]
-
- logInfo("requestPathBase: " + requestPathBase)
-
- // Read BLOCKNUM file and send back the total number of blocks
- val blockNumFilePath = "%s/%s-BLOCKNUM".format(shuffleDir,
- requestPathBase)
- val blockNumIn =
- new ObjectInputStream(new FileInputStream(blockNumFilePath))
- val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
- blockNumIn.close()
-
- oos.writeObject(BLOCKNUM)
- oos.flush()
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = Shuffle.MaxChatBlocks
-
- while (keepSending && numBlocksToSend > 0) {
- // Receive specific block request
- val blockId = ois.readObject.asInstanceOf[Int]
-
- // Ready to send
- var requestPath = requestPathBase + "-" + blockId
-
- // Open the file
- var requestedFile: File = null
- var requestedFileLen = -1
- try {
- requestedFile = new File(shuffleDir + "/" + requestPath)
- requestedFileLen = requestedFile.length.toInt
- } catch {
- case e: Exception => { }
- }
-
- // Send the length of the requestPath to let the receiver know that
- // transfer is about to start
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- oos.writeObject(requestedFileLen)
- oos.flush()
-
- logInfo("requestedFileLen = " + requestedFileLen)
-
- // Read and send the requested file
- if (requestedFileLen != -1) {
- // Read
- var byteArray = new Array[Byte](requestedFileLen)
- val bis =
- new BufferedInputStream(new FileInputStream(requestedFile))
-
- var bytesRead = bis.read(byteArray, 0, byteArray.length)
- var alreadyRead = bytesRead
-
- while (alreadyRead < requestedFileLen) {
- bytesRead = bis.read(byteArray, alreadyRead,
- (byteArray.length - alreadyRead))
- if(bytesRead > 0) {
- alreadyRead = alreadyRead + bytesRead
- }
- }
- bis.close()
-
- // Send
- bos.write(byteArray, 0, byteArray.length)
- bos.flush()
-
- // Update loop variables
- numBlocksToSend = numBlocksToSend - 1
-
- curTime = System.currentTimeMillis
- // Revoke sending only if there is anyone waiting in the queue
- // TODO: Turning OFF the optimization so that reducers go back to
- // tracker get advice
- if (curTime - startTime >= Shuffle.MaxChatTime /* &&
- threadPool.getQueue.size > 0 */) {
- keepSending = false
- }
- } else {
- // Close the connection
- }
- }
- } catch {
- // If something went wrong, e.g., the worker at the other end died etc
- // then close everything up
- // Exception can happen if the receiver stops receiving
- // EOFException is expected to happen because receiver can break
- // connection as soon as it has all the blocks
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logInfo("ShuffleServerThread had a " + e)
- }
- } finally {
- logInfo("ShuffleServerThread is closing streams and sockets")
- ois.close()
- // TODO: Following can cause "java.net.SocketException: Socket closed"
- oos.close()
- bos.close()
- clientSocket.close()
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
deleted file mode 100644
index 4a38a8d7ff..0000000000
--- a/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
+++ /dev/null
@@ -1,802 +0,0 @@
-package spark
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Random, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-/**
- * An implementation of shuffle using local files served through custom server
- * where receivers create simultaneous connections to multiple servers by
- * setting the 'spark.shuffle.maxRxConnections' config option.
- *
- * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally,
- * maxTxConnections >= maxRxConnections * numReducersPerMachine
- *
- * 'spark.shuffle.trackerStrategy' decides which strategy to use in the tracker
- *
- * TODO: Add support for compression when spark.compress is set to true.
- */
-@serializable
-class TrackedCustomParallelLocalFileShuffle[K, V, C]
-extends Shuffle[K, V, C] with Logging {
- @transient var totalSplits = 0
- @transient var hasSplits = 0
- @transient var hasSplitsBitVector: BitSet = null
- @transient var splitsInRequestBitVector: BitSet = null
-
- @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
- @transient var combiners: HashMap[K,C] = null
-
-  /**
-   * Runs the shuffle in two phases and returns the RDD of merged
-   * (key, combiner) pairs.
-   *
-   * Map phase: every input split hash-partitions its records into
-   * `numOutputSplits` in-memory buckets (combining the values of a key with
-   * `createCombiner` / `mergeValue`), writes each bucket to its own local
-   * file, and reports a SplitInfo built from the companion object's
-   * serverAddress / serverPort.
-   *
-   * Reduce phase: a ShuffleTracker is started, then each output partition
-   * repeatedly asks the tracker which map outputs to pull, fetches them with
-   * ShuffleClient threads, and finally merges everything it received with
-   * `mergeCombiners` in a ShuffleConsumer.
-   *
-   * @param input           RDD of (key, value) pairs to shuffle
-   * @param numOutputSplits number of output partitions (buckets per map task)
-   * @param createCombiner  builds the initial combiner from the first value of a key
-   * @param mergeValue      folds one more value into an existing combiner
-   * @param mergeCombiners  merges two combiners of the same key on the reduce side
-   * @return an RDD with `numOutputSplits` partitions of (key, combiner) pairs
-   */
-  override def compute(input: RDD[(K, V)],
-                       numOutputSplits: Int,
-                       createCombiner: V => C,
-                       mergeValue: (C, V) => C,
-                       mergeCombiners: (C, C) => C)
-  : RDD[(K, C)] =
-  {
-    val sc = input.sparkContext
-    val shuffleId = TrackedCustomParallelLocalFileShuffle.newShuffleId()
-    logInfo("Shuffle ID: " + shuffleId)
-
-    val splitRdd = new NumberedSplitRDD(input)
-    val numInputSplits = splitRdd.splits.size
-
-    // Run a parallel map and collect to write the intermediate data files
-    val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
-      val myIndex = pair._1
-      val myIterator = pair._2
-      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
-      for ((k, v) <- myIterator) {
-        var bucketId = k.hashCode % numOutputSplits
-        if (bucketId < 0) { // Fix bucket ID if hash code was negative
-          bucketId += numOutputSplits
-        }
-        val bucket = buckets(bucketId)
-        bucket(k) = bucket.get(k) match {
-          case Some(c) => mergeValue(c, v)
-          case None => createCombiner(v)
-        }
-      }
-
-      for (i <- 0 until numOutputSplits) {
-        val file = TrackedCustomParallelLocalFileShuffle.getOutputFile(shuffleId,
-          myIndex, i)
-        val writeStartTime = System.currentTimeMillis
-        logInfo("BEGIN WRITE: " + file)
-        val out = new ObjectOutputStream(new FileOutputStream(file))
-        try {
-          buckets(i).foreach(entry => out.writeObject(entry))
-        } finally {
-          // Release the file handle even when serializing a pair fails
-          out.close()
-        }
-        logInfo("END WRITE: " + file)
-        val writeTime = System.currentTimeMillis - writeStartTime
-        logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
-      }
-
-      (SplitInfo (TrackedCustomParallelLocalFileShuffle.serverAddress,
-        TrackedCustomParallelLocalFileShuffle.serverPort, myIndex))
-    }).collect()
-
-    // Start tracker
-    var shuffleTracker = new ShuffleTracker(outputLocs)
-    shuffleTracker.setDaemon(true)
-    shuffleTracker.start()
-    logInfo("ShuffleTracker started...")
-
-    // Return an RDD that does each of the merges for a given partition
-    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
-    return indexes.flatMap((myId: Int) => {
-      totalSplits = outputLocs.size
-      hasSplits = 0
-      hasSplitsBitVector = new BitSet(totalSplits)
-      splitsInRequestBitVector = new BitSet(totalSplits)
-
-      receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
-      combiners = new HashMap[K, C]
-
-      var threadPool = Shuffle.newDaemonFixedThreadPool(
-        Shuffle.MaxRxConnections)
-
-      while (hasSplits < totalSplits) {
-        // Local status of hasSplitsBitVector and splitsInRequestBitVector
-        val localSplitInfo = getLocalSplitInfo(myId)
-
-        // DO NOT talk to the tracker if all the required splits are already busy
-        val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality
-
-        var numThreadsToCreate =
-          Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) -
-          threadPool.getActiveCount
-
-        while (hasSplits < totalSplits && numThreadsToCreate > 0) {
-          // Receive which split to pull from the tracker
-          logInfo("Talking to tracker...")
-          val startTime = System.currentTimeMillis
-          val splitIndices = getTrackerSelectedSplit(myId)
-          logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime))
-
-          if (splitIndices.size > 0) {
-            splitIndices.foreach { splitIndex =>
-              val selectedSplitInfo = outputLocs(splitIndex)
-              val requestSplit =
-                "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
-
-              // Mark splitIndex as in transit BEFORE handing it to the pool.
-              // The ShuffleClient clears this bit when it finishes; if the
-              // bit were set after execute(), a client that fails quickly
-              // could clear it first and the split would then look
-              // permanently "in request" and never be retried.
-              splitsInRequestBitVector.synchronized {
-                splitsInRequestBitVector.set(splitIndex)
-              }
-
-              threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
-                requestSplit, myId))
-            }
-          } else {
-            // Tracker replied back with a NO. Sleep for a while.
-            Thread.sleep(Shuffle.MinKnockInterval)
-            numThreadsToCreate = 0
-          }
-
-          numThreadsToCreate = numThreadsToCreate - splitIndices.size
-        }
-
-        // Sleep for a while before creating new threads
-        Thread.sleep(Shuffle.MinKnockInterval)
-      }
-
-      threadPool.shutdown()
-
-      // Start consumer
-      // TODO: Consumption is delayed until everything has been received.
-      // Otherwise it interferes with network performance
-      var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
-      shuffleConsumer.setDaemon(true)
-      shuffleConsumer.start()
-      logInfo("ShuffleConsumer started...")
-
-      // Don't return until consumption is finished
-      // while (receivedData.size > 0) {
-      //   Thread.sleep(Shuffle.MinKnockInterval)
-      // }
-
-      // Wait till shuffleConsumer is done
-      shuffleConsumer.join
-
-      combiners
-    })
-  }
-
-  /**
-   * Builds a SplitInfo describing this reducer's current progress.
-   *
-   * The returned object's bit vector is a copy of hasSplitsBitVector with
-   * every bit of splitsInRequestBitVector OR-ed in, i.e. it covers the splits
-   * already received as well as the ones currently being requested.
-   *
-   * @param myId id of the reducer, stored as the third SplitInfo argument
-   */
-  private def getLocalSplitInfo(myId: Int): SplitInfo = {
-    val hostAddress = InetAddress.getLocalHost.getHostAddress
-    val snapshot = SplitInfo(hostAddress, SplitInfo.UnusedParam, myId)
-
-    snapshot.hasSplits = hasSplits
-
-    // Copy under the lock so later updates don't leak into the snapshot
-    hasSplitsBitVector.synchronized {
-      snapshot.hasSplitsBitVector =
-        hasSplitsBitVector.clone.asInstanceOf[BitSet]
-    }
-
-    // Fold the in-flight requests into the same vector
-    splitsInRequestBitVector.synchronized {
-      snapshot.hasSplitsBitVector.or(splitsInRequestBitVector)
-    }
-
-    snapshot
-  }
-
-  // Picks, uniformly at random, one split that has neither been received nor
-  // is currently being requested. Returns -1 when no such split exists.
-  private def selectRandomSplit: Int = {
-    // Scan both bit vectors while holding this object's lock
-    val candidates = synchronized {
-      (0 until totalSplits).filter { i =>
-        !(hasSplitsBitVector.get(i) || splitsInRequestBitVector.get(i))
-      }
-    }
-
-    if (candidates.isEmpty) {
-      -1
-    } else {
-      val pick = TrackedCustomParallelLocalFileShuffle.ranGen.nextInt(
-        candidates.size)
-      candidates(pick)
-    }
-  }
-
-  /**
-   * Talks to the tracker and receives the indices of the splits this reducer
-   * should fetch next.
-   *
-   * Sends Shuffle.ReducerEntering followed by the local SplitInfo, then reads
-   * the tracker's reply. Returns an empty buffer without contacting the
-   * tracker when every split is already received or in request, and also when
-   * the exchange with the tracker fails (the failure is logged).
-   *
-   * @param myId id of the requesting reducer
-   * @return indices into outputLocs chosen by the tracker; possibly empty
-   */
-  private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = {
-    // Local status of hasSplitsBitVector and splitsInRequestBitVector
-    val localSplitInfo = getLocalSplitInfo(myId)
-
-    // DO NOT talk to the tracker if all the required splits are already busy
-    if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
-      return ArrayBuffer[Int]()
-    }
-
-    val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
-      Shuffle.MasterTrackerPort)
-    var oosTracker: ObjectOutputStream = null
-    var oisTracker: ObjectInputStream = null
-
-    var selectedSplitIndices = ArrayBuffer[Int]()
-
-    // Setup the timeout mechanism
-    val timeOutTask = new TimerTask {
-      override def run: Unit = {
-        logInfo("Waited enough for tracker response... Take random response...")
-
-        // sockets will be closed in finally
-        // TODO: Sometimes timer wont go off
-
-        // TODO: Selecting randomly here. Tracker won't know about it and get an
-        // assertion failure when this thread leaves
-
-        selectedSplitIndices = ArrayBuffer(selectRandomSplit)
-      }
-    }
-
-    val timeOutTimer = new Timer
-    // TODO: Which timeout to use?
-    // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval)
-
-    try {
-      // The stream handshake happens inside the try so that a failure here
-      // still closes the socket in the finally block below
-      oosTracker =
-        new ObjectOutputStream(clientSocketToTracker.getOutputStream)
-      oosTracker.flush()
-      oisTracker =
-        new ObjectInputStream(clientSocketToTracker.getInputStream)
-
-      // Send intention
-      oosTracker.writeObject(Shuffle.ReducerEntering)
-      oosTracker.flush()
-
-      // Send what this reducer has
-      oosTracker.writeObject(localSplitInfo)
-      oosTracker.flush()
-
-      // Receive reply from the tracker
-      selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]]
-    } catch {
-      case e: Exception => {
-        logInfo("getTrackerSelectedSplit had a " + e)
-      }
-    } finally {
-      // Turn the timer OFF on every path; an un-cancelled Timer keeps its
-      // background thread alive
-      timeOutTimer.cancel()
-
-      if (oisTracker != null) {
-        oisTracker.close()
-      }
-      if (oosTracker != null) {
-        oosTracker.close()
-      }
-      clientSocketToTracker.close()
-    }
-
-    selectedSplitIndices
-  }
-
-  /**
-   * Thread that listens on Shuffle.MasterTrackerPort and answers reducers.
-   *
-   * A reducer opens a connection and sends an intention:
-   *  - Shuffle.ReducerEntering, followed by its SplitInfo: the tracker replies
-   *    with the split indices selected by the configured strategy.
-   *  - Shuffle.ReducerLeaving, followed by its SplitInfo and a ReceptionStats:
-   *    the tracker updates the strategy and replies with an ACK.
-   *
-   * The strategy class is read from the 'spark.shuffle.trackerStrategy'
-   * system property and instantiated reflectively.
-   *
-   * @param outputLocs locations of the map outputs, handed to the strategy
-   */
-  class ShuffleTracker(outputLocs: Array[SplitInfo])
-  extends Thread with Logging {
-    var threadPool = Shuffle.newDaemonCachedThreadPool
-    var serverSocket: ServerSocket = null
-
-    // Create trackerStrategy object
-    val trackerStrategyClass = System.getProperty(
-      "spark.shuffle.trackerStrategy",
-      "spark.BalanceConnectionsShuffleTrackerStrategy")
-
-    val trackerStrategy =
-      Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy]
-
-    // Must initialize here by supplying the outputLocs param
-    // TODO: This could be avoided by directly passing it to the constructor
-    trackerStrategy.initialize(outputLocs)
-
-    override def run: Unit = {
-      serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
-      logInfo("ShuffleTracker" + serverSocket)
-
-      try {
-        while (true) {
-          var clientSocket: Socket = null
-          try {
-            clientSocket = serverSocket.accept()
-          } catch {
-            case e: Exception => {
-              logInfo("ShuffleTracker had a " + e)
-            }
-          }
-
-          if (clientSocket != null) {
-            val acceptedSocket = clientSocket
-            try {
-              threadPool.execute(new Runnable {
-                override def run: Unit = {
-                  serveReducer(acceptedSocket)
-                }
-              })
-            } catch {
-              // In failure, close socket here; else, client thread will close
-              // NOTE(review): ExecutorService.execute signals rejection with
-              // RejectedExecutionException, which this clause does not match
-              case ioe: IOException => {
-                clientSocket.close()
-              }
-            }
-          }
-        }
-      } finally {
-        // The accept loop only ends by an exception, so the pool has to be
-        // shut down here; code placed after the loop is never reached
-        try {
-          serverSocket.close()
-        } finally {
-          threadPool.shutdown()
-        }
-      }
-    }
-
-    // Handles one reducer connection: reads the intention, updates the
-    // strategy and replies. Always closes the socket before returning.
-    private def serveReducer(clientSocket: Socket): Unit = {
-      var oos: ObjectOutputStream = null
-      var ois: ObjectInputStream = null
-
-      try {
-        // Stream setup is inside the try so that a failed handshake still
-        // closes the client socket
-        oos = new ObjectOutputStream(clientSocket.getOutputStream)
-        oos.flush()
-        ois = new ObjectInputStream(clientSocket.getInputStream)
-
-        // Receive intention
-        val reducerIntention = ois.readObject.asInstanceOf[Int]
-
-        if (reducerIntention == Shuffle.ReducerEntering) {
-          // Receive what the reducer has
-          val reducerSplitInfo =
-            ois.readObject.asInstanceOf[SplitInfo]
-
-          // Select split and update stats if necessary
-          var selectedSplitIndices = ArrayBuffer[Int]()
-          trackerStrategy.synchronized {
-            selectedSplitIndices = trackerStrategy.selectSplit(
-              reducerSplitInfo)
-          }
-
-          // Send reply back
-          oos.writeObject(selectedSplitIndices)
-          oos.flush()
-
-          // Update internal stats, only if receiver got the reply
-          trackerStrategy.synchronized {
-            trackerStrategy.AddReducerToSplit(reducerSplitInfo,
-              selectedSplitIndices)
-          }
-        }
-        else if (reducerIntention == Shuffle.ReducerLeaving) {
-          val reducerSplitInfo =
-            ois.readObject.asInstanceOf[SplitInfo]
-
-          // Receive reception stats: how many blocks the reducer
-          // read in how much time and from where
-          val receptionStat =
-            ois.readObject.asInstanceOf[ReceptionStats]
-
-          // Update stats
-          trackerStrategy.synchronized {
-            trackerStrategy.deleteReducerFrom(reducerSplitInfo,
-              receptionStat)
-          }
-
-          // Send ACK
-          oos.writeObject(receptionStat.serverSplitIndex)
-          oos.flush()
-        }
-        else {
-          throw new SparkException("Undefined reducerIntention")
-        }
-      } catch {
-        // EOFException is expected to happen because receiver can
-        // break connection due to timeout and pick random instead
-        case eofe: java.io.EOFException => { }
-        case e: Exception => {
-          logInfo("ShuffleTracker had a " + e)
-        }
-      } finally {
-        if (ois != null) {
-          ois.close()
-        }
-        if (oos != null) {
-          oos.close()
-        }
-        clientSocket.close()
-      }
-    }
-  }
-
-  /**
-   * Thread that drains receivedData and merges every deserialized
-   * (key, combiner) pair into `combiners`.
-   *
-   * It runs for as long as receivedData is non-empty, so it is meant to be
-   * started once all blocks have been queued.
-   *
-   * @param mergeCombiners merges the stored combiner of a key with a new one
-   */
-  class ShuffleConsumer(mergeCombiners: (C, C) => C)
-  extends Thread with Logging {
-    override def run: Unit = {
-      // Run until all splits are here
-      while (receivedData.size > 0) {
-        var recvByteArray: Array[Byte] = null
-
-        try {
-          recvByteArray = receivedData.take()._2
-        } catch {
-          case e: Exception => {
-            logInfo("Exception during taking data from receivedData")
-          }
-        }
-
-        // If take() failed there is nothing to deserialize in this round;
-        // wrapping a null array would only raise a NullPointerException
-        if (recvByteArray != null) {
-          var inputStream: ObjectInputStream = null
-
-          try {
-            // Created inside the try: the constructor reads the stream header
-            // and reports a truncated block with an EOFException as well
-            inputStream =
-              new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
-
-            // The block carries no pair count; EOFException marks its end
-            while (true) {
-              val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
-              combiners(k) = combiners.get(k) match {
-                case Some(oldC) => mergeCombiners(oldC, c)
-                case None => c
-              }
-            }
-          } catch {
-            case e: EOFException => { }
-          } finally {
-            if (inputStream != null) {
-              inputStream.close()
-            }
-          }
-        }
-      }
-    }
-  }
-
-  /**
-   * Fetches one map output from a remote server and queues it for the
-   * consumer.
-   *
-   * Protocol with the server: send `requestSplit`, read the file length as an
-   * Int (-1 means the server does not have the file), then read exactly that
-   * many raw bytes from the socket. On success the bytes are put into
-   * receivedData and the split is marked as received; on any failure the
-   * in-request bit is cleared so the split can be asked for again. Either
-   * way the tracker is told once that this client is leaving.
-   *
-   * @param splitIndex      index of the split in outputLocs / the bit vectors
-   * @param serversplitInfo address and port of the server holding the split
-   * @param requestSplit    path-like identifier sent to the server
-   * @param myId            id of the reducer this client works for
-   */
-  class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo,
-    requestSplit: String, myId: Int)
-  extends Thread with Logging {
-    private var peerSocketToSource: Socket = null
-    private var oosSource: ObjectOutputStream = null
-    private var oisSource: ObjectInputStream = null
-
-    private var receptionSucceeded = false
-
-    // Make sure that multiple messages don't go to the tracker
-    private var alreadySentLeavingNotification = false
-
-    // Keep track of bytes received and time spent
-    private var numBytesReceived = 0
-    private var totalTimeSpent = 0
-
-    override def run: Unit = {
-      // Setup the timeout mechanism
-      val timeOutTask = new TimerTask {
-        override def run: Unit = {
-          cleanUp()
-        }
-      }
-
-      val timeOutTimer = new Timer
-      timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
-
-      // Create a temp variable to be used in different places
-      val requestPath = "http://%s:%d/shuffle/%s".format(
-        serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit)
-
-      logInfo("ShuffleClient started... => " + requestPath)
-
-      try {
-        // Connect to the source
-        peerSocketToSource = new Socket(
-          serversplitInfo.hostAddress, serversplitInfo.listenPort)
-        oosSource =
-          new ObjectOutputStream(peerSocketToSource.getOutputStream)
-        oosSource.flush()
-        val isSource = peerSocketToSource.getInputStream
-        oisSource = new ObjectInputStream(isSource)
-
-        // Send the request
-        oosSource.writeObject(requestSplit)
-        oosSource.flush()
-
-        // Receive the length of the requested file
-        val requestedFileLen = oisSource.readObject.asInstanceOf[Int]
-        logInfo("Received requestedFileLen = " + requestedFileLen)
-
-        // Turn the timer OFF, if the sender responds before timeout
-        timeOutTimer.cancel()
-
-        // Receive the file
-        if (requestedFileLen != -1) {
-          val readStartTime = System.currentTimeMillis
-          logInfo("BEGIN READ: " + requestPath)
-
-          // Receive data in an Array[Byte]
-          val recvByteArray = new Array[Byte](requestedFileLen)
-          var alreadyRead = 0
-
-          while (alreadyRead < requestedFileLen) {
-            val bytesRead = isSource.read(recvByteArray, alreadyRead,
-              requestedFileLen - alreadyRead)
-            // read() returns -1 once the sender has closed the connection.
-            // Without this check the loop would spin forever on a block that
-            // was cut short.
-            if (bytesRead < 0) {
-              throw new EOFException("Connection closed after " + alreadyRead +
-                " of " + requestedFileLen + " bytes of " + requestSplit)
-            }
-            alreadyRead = alreadyRead + bytesRead
-          }
-
-          // Make it available to the consumer
-          try {
-            receivedData.put((splitIndex, recvByteArray))
-          } catch {
-            case e: Exception => {
-              logInfo("Exception during putting data into receivedData")
-            }
-          }
-
-          // TODO: Updating stats before consumption is completed
-          hasSplitsBitVector.synchronized {
-            hasSplitsBitVector.set(splitIndex)
-            hasSplits += 1
-          }
-
-          // We have received splitIndex
-          splitsInRequestBitVector.synchronized {
-            splitsInRequestBitVector.set(splitIndex, false)
-          }
-
-          receptionSucceeded = true
-
-          logInfo("END READ: " + requestPath)
-          val readTime = System.currentTimeMillis - readStartTime
-          logInfo("Reading " + requestPath + " took " + readTime + " millis.")
-
-          // Update stats
-          numBytesReceived = numBytesReceived + requestedFileLen
-          totalTimeSpent = totalTimeSpent + readTime.toInt
-        } else {
-          throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit)
-        }
-      } catch {
-        // EOFException is expected to happen because sender can break
-        // connection due to timeout
-        case eofe: java.io.EOFException => { }
-        case e: Exception => {
-          logInfo("ShuffleClient had a " + e)
-        }
-      } finally {
-        // Stop the timer on failure paths too; cancel() may be called
-        // repeatedly, so the earlier cancel on the success path is harmless
-        timeOutTimer.cancel()
-
-        // If reception failed, unset for future retry
-        if (!receptionSucceeded) {
-          splitsInRequestBitVector.synchronized {
-            splitsInRequestBitVector.set(splitIndex, false)
-          }
-        }
-        cleanUp()
-      }
-    }
-
-    // Connect to the tracker and update its stats
-    private def sendLeavingNotification(): Unit = synchronized {
-      if (!alreadySentLeavingNotification) {
-        val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
-          Shuffle.MasterTrackerPort)
-        val oosTracker =
-          new ObjectOutputStream(clientSocketToTracker.getOutputStream)
-        oosTracker.flush()
-        val oisTracker =
-          new ObjectInputStream(clientSocketToTracker.getInputStream)
-
-        try {
-          // Send intention
-          oosTracker.writeObject(Shuffle.ReducerLeaving)
-          oosTracker.flush()
-
-          // Send reducerSplitInfo
-          oosTracker.writeObject(getLocalSplitInfo(myId))
-          oosTracker.flush()
-
-          // Send reception stats
-          oosTracker.writeObject(ReceptionStats(
-            numBytesReceived, totalTimeSpent, splitIndex))
-          oosTracker.flush()
-
-          // Receive ACK. No need to do anything with that
-          oisTracker.readObject.asInstanceOf[Int]
-
-          // Now update alreadySentLeavingNotification
-          alreadySentLeavingNotification = true
-        } catch {
-          case e: Exception => {
-            logInfo("sendLeavingNotification had a " + e)
-          }
-        } finally {
-          oisTracker.close()
-          oosTracker.close()
-          clientSocketToTracker.close()
-        }
-      }
-    }
-
-    // Notifies the tracker, then closes the connection to the mapper. Called
-    // both from run's finally block and from the timeout task.
-    private def cleanUp(): Unit = {
-      try {
-        // Update tracker stats first
-        sendLeavingNotification()
-      } finally {
-        // Clean up the connections to the mapper even when the tracker
-        // could not be reached
-        if (oisSource != null) {
-          oisSource.close()
-        }
-        if (oosSource != null) {
-          oosSource.close()
-        }
-        if (peerSocketToSource != null) {
-          peerSocketToSource.close()
-        }
-      }
-    }
-  }
-}
-
-object TrackedCustomParallelLocalFileShuffle extends Logging {
- private var initialized = false
- private var nextShuffleId = new AtomicLong(0)
-
- // Variables initialized by initializeIfNeeded()
- private var shuffleDir: File = null
-
- private var shuffleServer: ShuffleServer = null
- private var serverAddress = InetAddress.getLocalHost.getHostAddress
- private var serverPort: Int = -1
-
- // Random number generator
- var ranGen = new Random
-
-  /**
-   * One-time, synchronized setup of this object: creates a fresh
-   * "spark-local-<uuid>" directory under the 'spark.local.dir' property
-   * (default "/tmp"), creates its "shuffle" sub-directory, and starts the
-   * ShuffleServer as a daemon thread.
-   *
-   * Exits the JVM with status 1 if no local directory could be created in
-   * 10 attempts.
-   */
-  private def initializeIfNeeded() = synchronized {
-    if (!initialized) {
-      // TODO: localDir should be created by some mechanism common to Spark
-      // so that it can be shared among shuffle, broadcast, etc
-      val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
-      var tries = 0
-      var foundLocalDir = false
-      var localDir: File = null
-
-      while (!foundLocalDir && tries < 10) {
-        tries += 1
-        try {
-          localDir = new File(localDirRoot, "spark-local-" + UUID.randomUUID)
-          // mkdirs() reports failure by returning false rather than by
-          // throwing, so only count the attempt when it returned true
-          if (!localDir.exists && localDir.mkdirs()) {
-            foundLocalDir = true
-          }
-        } catch {
-          case e: Exception =>
-            logWarning("Attempt " + tries + " to create local dir failed", e)
-        }
-      }
-      if (!foundLocalDir) {
-        logError("Failed 10 attempts to create local dir in " + localDirRoot)
-        System.exit(1)
-      }
-      shuffleDir = new File(localDir, "shuffle")
-      shuffleDir.mkdirs()
-      logInfo("Shuffle dir: " + shuffleDir)
-
-      // Create and start the shuffleServer
-      shuffleServer = new ShuffleServer
-      shuffleServer.setDaemon(true)
-      shuffleServer.start()
-      logInfo("ShuffleServer started...")
-
-      initialized = true
-    }
-  }
-
-  /**
-   * Returns the file `<shuffleDir>/<shuffleId>/<inputId>/<outputId>`,
-   * creating the parent directories (and initializing this object) first.
-   * The file itself is not created here.
-   */
-  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
-    initializeIfNeeded()
-    val parentDir = new File(shuffleDir, shuffleId + "/" + inputId)
-    parentDir.mkdirs()
-    new File(parentDir, outputId.toString)
-  }
-
-  /** Returns the current value of the shuffle id counter and increments it. */
-  def newShuffleId(): Long = nextShuffleId.getAndIncrement()
-
-  /**
-   * Thread that serves the local shuffle files to reducers.
-   *
-   * It binds a ServerSocket to an ephemeral port (published through
-   * serverPort) and hands every accepted connection to a
-   * ShuffleServerThread running on a fixed-size pool of
-   * Shuffle.MaxTxConnections threads.
-   */
-  class ShuffleServer
-  extends Thread with Logging {
-    var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
-
-    var serverSocket: ServerSocket = null
-
-    override def run: Unit = {
-      serverSocket = new ServerSocket(0)
-      serverPort = serverSocket.getLocalPort
-
-      logInfo("ShuffleServer started with " + serverSocket)
-      logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
-
-      try {
-        while (true) {
-          var clientSocket: Socket = null
-          try {
-            clientSocket = serverSocket.accept()
-          } catch {
-            case e: Exception => {
-              // Keep serving, but leave a trace of the failed accept
-              logInfo("ShuffleServer had a " + e)
-            }
-          }
-          if (clientSocket != null) {
-            logInfo("Serve: Accepted new client connection:" + clientSocket)
-            try {
-              threadPool.execute(new ShuffleServerThread(clientSocket))
-            } catch {
-              // In failure, close socket here; else, the thread will close it
-              case ioe: IOException => {
-                clientSocket.close()
-              }
-            }
-          }
-        }
-      } finally {
-        // The accept loop only ends by an exception, so the pool has to be
-        // shut down here; code placed after the loop is never reached
-        try {
-          if (serverSocket != null) {
-            logInfo("ShuffleServer now stopping...")
-            serverSocket.close()
-          }
-        } finally {
-          threadPool.shutdown()
-        }
-      }
-    }
-
-    /**
-     * Serves a single request on `clientSocket`: reads the requested path,
-     * replies with the file length as an Int (-1 when there is no such
-     * file), then writes the raw file bytes and closes the connection.
-     */
-    class ShuffleServerThread(val clientSocket: Socket)
-    extends Thread with Logging {
-      private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
-      os.flush()
-      private val bos = new BufferedOutputStream(os)
-      bos.flush()
-      private val oos = new ObjectOutputStream(os)
-      oos.flush()
-      private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
-      logInfo("new ShuffleServerThread is running")
-
-      override def run: Unit = {
-        try {
-          // Receive requestPath from the receiver
-          val requestPath = ois.readObject.asInstanceOf[String]
-          logInfo("requestPath: " + shuffleDir + "/" + requestPath)
-
-          // NOTE(review): requestPath arrives over the network and is joined
-          // to shuffleDir without checking for ".." components - confirm that
-          // only trusted reducers can reach this port
-          val requestedFile = new File(shuffleDir + "/" + requestPath)
-
-          // File.length returns 0, not an error, for a file that does not
-          // exist, so existence has to be tested explicitly to be able to
-          // answer -1
-          var requestedFileLen = -1
-          try {
-            if (requestedFile.isFile) {
-              requestedFileLen = requestedFile.length.toInt
-            }
-          } catch {
-            case e: Exception => {
-              logInfo("ShuffleServerThread could not inspect " +
-                requestedFile + ": " + e)
-            }
-          }
-
-          // Send the length of the requestPath to let the receiver know that
-          // transfer is about to start
-          // In the case of receiver timeout and connection close, this will
-          // throw a java.net.SocketException: Broken pipe
-          oos.writeObject(requestedFileLen)
-          oos.flush()
-
-          logInfo("requestedFileLen = " + requestedFileLen)
-
-          // Read and send the requested file
-          if (requestedFileLen != -1) {
-            // Read
-            val byteArray = new Array[Byte](requestedFileLen)
-            val bis =
-              new BufferedInputStream(new FileInputStream(requestedFile))
-
-            try {
-              var alreadyRead = 0
-              while (alreadyRead < requestedFileLen) {
-                val bytesRead = bis.read(byteArray, alreadyRead,
-                  requestedFileLen - alreadyRead)
-                // -1 means the file ended before the announced length;
-                // fail instead of looping forever
-                if (bytesRead < 0) {
-                  throw new EOFException("Unexpected end of " + requestedFile +
-                    " after " + alreadyRead + " bytes")
-                }
-                alreadyRead = alreadyRead + bytesRead
-              }
-            } finally {
-              bis.close()
-            }
-
-            // Send
-            bos.write(byteArray, 0, byteArray.length)
-            bos.flush()
-          } else {
-            // Close the connection
-          }
-        } catch {
-          // If something went wrong, e.g., the worker at the other end died etc
-          // then close everything up
-          // Exception can happen if the receiver stops receiving
-          case e: Exception => {
-            logInfo("ShuffleServerThread had a " + e)
-          }
-        } finally {
-          logInfo("ShuffleServerThread is closing streams and sockets")
-          ois.close()
-          // TODO: Following can cause "java.net.SocketException: Socket closed"
-          oos.close()
-          bos.close()
-          clientSocket.close()
-        }
-      }
-    }
-  }
-}