author    Mosharaf Chowdhury <mosharaf@cs.berkeley.edu>    2011-04-27 20:47:07 -0700
committer Mosharaf Chowdhury <mosharaf@cs.berkeley.edu>    2011-04-27 20:47:07 -0700
commit    9d78779257b156bec335af4ab2a66bb3cac30ca6 (patch)
tree      738420e4f3653510e2803b6cb6764c261feeb8c8 /core
parent    4e4c41026c33d9f8fe75137f32266de75a0aa30e (diff)
parent    ac7e066383a6878beb0618597c2be6fa9eb1982e (diff)
Merge branch 'mos-shuffle-tracked' into mos-bt
Conflicts:
	core/src/main/scala/spark/Broadcast.scala
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/BasicLocalFileShuffle.scala (renamed from core/src/main/scala/spark/LocalFileShuffle.scala)  32
-rw-r--r--  core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala  629
-rw-r--r--  core/src/main/scala/spark/CustomBlockedLocalFileShuffle.scala  654
-rw-r--r--  core/src/main/scala/spark/CustomParallelFakeShuffle.scala  495
-rw-r--r--  core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala  535
-rw-r--r--  core/src/main/scala/spark/CustomParallelLocalFileShuffle.scala  541
-rw-r--r--  core/src/main/scala/spark/DfsShuffle.scala  2
-rw-r--r--  core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala  471
-rw-r--r--  core/src/main/scala/spark/HttpParallelLocalFileShuffle.scala  391
-rw-r--r--  core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala  465
-rw-r--r--  core/src/main/scala/spark/Shuffle.scala  113
-rw-r--r--  core/src/main/scala/spark/ShuffleTrackerStrategy.scala  470
-rw-r--r--  core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala  926
-rw-r--r--  core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala  930
-rw-r--r--  core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala  802
15 files changed, 7444 insertions, 12 deletions
diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/BasicLocalFileShuffle.scala
index 367599cfb4..3c3f132083 100644
--- a/core/src/main/scala/spark/LocalFileShuffle.scala
+++ b/core/src/main/scala/spark/BasicLocalFileShuffle.scala
@@ -7,14 +7,13 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{ArrayBuffer, HashMap}
-
/**
- * A simple implementation of shuffle using local files served through HTTP.
+ * A basic implementation of shuffle using local files served through HTTP.
*
* TODO: Add support for compression when spark.compress is set to true.
*/
@serializable
-class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+class BasicLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
override def compute(input: RDD[(K, V)],
numOutputSplits: Int,
createCombiner: V => C,
@@ -23,7 +22,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
: RDD[(K, C)] =
{
val sc = input.sparkContext
- val shuffleId = LocalFileShuffle.newShuffleId()
+ val shuffleId = BasicLocalFileShuffle.newShuffleId()
logInfo("Shuffle ID: " + shuffleId)
val splitRdd = new NumberedSplitRDD(input)
@@ -46,13 +45,20 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
case None => createCombiner(v)
}
}
+
for (i <- 0 until numOutputSplits) {
- val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+ val file = BasicLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+ val writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + file)
val out = new ObjectOutputStream(new FileOutputStream(file))
buckets(i).foreach(pair => out.writeObject(pair))
out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = (System.currentTimeMillis - writeStartTime)
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
}
- (myIndex, LocalFileShuffle.serverUri)
+
+ (myIndex, BasicLocalFileShuffle.serverUri)
}).collect()
// Build a hashmap from server URI to list of splits (to facilitate
@@ -71,6 +77,8 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
for (i <- inputIds) {
val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: " + url)
val inputStream = new ObjectInputStream(new URL(url).openStream())
try {
while (true) {
@@ -84,6 +92,9 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
case e: EOFException => {}
}
inputStream.close()
+ logInfo("END READ: " + url)
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading " + url + " took " + readTime + " millis.")
}
}
combiners
@@ -91,8 +102,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
}
}
-
-object LocalFileShuffle extends Logging {
+object BasicLocalFileShuffle extends Logging {
private var initialized = false
private var nextShuffleId = new AtomicLong(0)
@@ -113,9 +123,9 @@ object LocalFileShuffle extends Logging {
while (!foundLocalDir && tries < 10) {
tries += 1
try {
- localDirUuid = UUID.randomUUID()
+ localDirUuid = UUID.randomUUID
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists()) {
+ if (!localDir.exists) {
localDir.mkdirs()
foundLocalDir = true
}
@@ -131,6 +141,7 @@ object LocalFileShuffle extends Logging {
shuffleDir = new File(localDir, "shuffle")
shuffleDir.mkdirs()
logInfo("Shuffle dir: " + shuffleDir)
+
val extServerPort = System.getProperty(
"spark.localFileShuffle.external.server.port", "-1").toInt
if (extServerPort != -1) {
@@ -149,6 +160,7 @@ object LocalFileShuffle extends Logging {
serverUri = server.uri
}
initialized = true
+ logInfo("Local URI: " + serverUri)
}
}
diff --git a/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala b/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala
new file mode 100644
index 0000000000..5aef43f302
--- /dev/null
+++ b/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala
@@ -0,0 +1,629 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATION (FOR NOW).
+ *
+ * An implementation of shuffle using local memory served through custom server
+ * where receivers create simultaneous connections to multiple servers by
+ * setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * By controlling the 'spark.shuffle.blockSize' config option one can also
+ * control the largest block size to divide each map output into. Essentially,
+ * instead of creating one large output file for each reducer, maps create
+ * multiple smaller files to enable a finer level of engagement.
+ *
+ * 'spark.shuffle.maxTxConnections' enforces a server-side cap. Ideally,
+ * maxTxConnections >= maxRxConnections * numReducersPerMachine
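+ * (e.g., with maxRxConnections = 4 and two reducers per machine,
+ * maxTxConnections should be at least 8)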
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class CustomBlockedInMemoryShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+
+ @transient var totalBlocksInSplit: Array[Int] = null
+ @transient var hasBlocksInSplit: Array[Int] = null
+
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = CustomBlockedInMemoryShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ var blockNum = 0
+ var isDirty = false
+
+ var splitName = ""
+ var baos: ByteArrayOutputStream = null
+ var oos: ObjectOutputStream = null
+
+ var writeStartTime: Long = 0
+
+ buckets(i).foreach(pair => {
+ // Open a new stream if necessary
+ if (!isDirty) {
+ splitName = CustomBlockedInMemoryShuffle.getSplitName(shuffleId,
+ myIndex, i, blockNum)
+
+ baos = new ByteArrayOutputStream
+ oos = new ObjectOutputStream(baos)
+
+ writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + splitName)
+ }
+
+ oos.writeObject(pair)
+ isDirty = true
+
+ // Close the old stream if it has crossed the blockSize limit
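+ // (a block may exceed blockSize by up to one serialized pair, since
+ // the size check runs only after each write)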
+ if (baos.size > Shuffle.BlockSize) {
+ CustomBlockedInMemoryShuffle.splitsCache(splitName) =
+ baos.toByteArray
+
+ logInfo("END WRITE: " + splitName)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ isDirty = false
+ oos.close()
+ }
+ })
+
+ if (isDirty) {
+ CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+
+ logInfo("END WRITE: " + splitName)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ oos.close()
+ }
+
+ // Store BLOCKNUM info
+ splitName = CustomBlockedInMemoryShuffle.getBlockNumOutputName(
+ shuffleId, myIndex, i)
+ baos = new ByteArrayOutputStream
+ oos = new ObjectOutputStream(baos)
+ oos.writeObject(blockNum)
+ CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+
+ // Close streams
+ oos.close()
+ }
+
+ (myIndex, CustomBlockedInMemoryShuffle.serverAddress,
+ CustomBlockedInMemoryShuffle.serverPort)
+ }).collect()
+
+ val splitsByUri = new ArrayBuffer[(String, Int, Int)]
+ for ((inputId, serverAddress, serverPort) <- outputLocs) {
+ splitsByUri += ((serverAddress, serverPort, inputId))
+ }
+
+ // TODO: Could broadcast outputLocs
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+
+ totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
+ hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
+
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ var numThreadsToCreate =
+ Math.min(totalSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Select a random split to pull
+ val splitIndex = selectRandomSplit
+
+ if (splitIndex != -1) {
+ val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
+
+ threadPool.execute(new ShuffleClient(serverAddress, serverPort,
+ shuffleId.toInt, inputId, myId, splitIndex))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+ // Don't return until consumption is finished
+ // while (receivedData.size > 0) {
+ // Thread.sleep(Shuffle.MinKnockInterval)
+ // }
+
+ // Wait till shuffleConsumer is done
+ shuffleConsumer.join
+
+ combiners
+ })
+ }
+
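+ // Picks a random split that has neither arrived nor been requested,
+ // spreading concurrent fetches across different source servers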
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+ synchronized {
+ for (i <- 0 until totalSplits) {
+ if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+ requiredSplits += i
+ }
+ }
+ }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(CustomBlockedInMemoryShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+ // Run until all splits are here
+ while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(hostAddress: String, listenPort: Int, shuffleId: Int,
+ inputId: Int, myId: Int, splitIndex: Int)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ override def run: Unit = {
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUpConnections()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
+
+ try {
+ // Everything will break if BLOCKNUM is not correctly received
+ // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
+ peerSocketToSource = new Socket(hostAddress, listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send path information
+ oosSource.writeObject((shuffleId, inputId, myId))
+
+ // TODO: Can be optimized. No need to do it every time.
+ // Receive BLOCKNUM
+ totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
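+ // (BLOCKNUM is the per-split block count the map task stored under
+ // getBlockNumOutputName; the blocks themselves are requested one by one)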
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel()
+
+ while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) {
+ // Set receptionSucceeded to false before trying for each block
+ receptionSucceeded = false
+
+ // Request specific block
+ oosSource.writeObject(hasBlocksInSplit(splitIndex))
+
+ // Good to go. First, receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+ val requestSplit = "%d/%d/%d-%d".format(shuffleId, inputId, myId,
+ hasBlocksInSplit(splitIndex))
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ while (alreadyRead != requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
+
+ // Split has been received only if all the blocks have been received
+ if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+ }
+
+ // Consistent state in accounting variables
+ receptionSucceeded = true
+
+ logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
+ } else {
+ throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
+ }
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ cleanUpConnections()
+ }
+ }
+
+ private def cleanUpConnections(): Unit = {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+ }
+ }
+}
+
+object CustomBlockedInMemoryShuffle extends Logging {
+ // Cache for keeping the splits around
+ val splitsCache = new HashMap[String, Array[Byte]]
+
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getSplitName(shuffleId: Long, inputId: Int, outputId: Int,
+ blockId: Int): String = {
+ initializeIfNeeded()
+ // Adding shuffleDir is unnecessary. Added to keep the parsers working
+ return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId,
+ blockId)
+ }
+
+ def getBlockNumOutputName(shuffleId: Long, inputId: Int,
+ outputId: Int): String = {
+ initializeIfNeeded()
+ return "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir, shuffleId, inputId,
+ outputId)
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
+ override def run: Unit = {
+ try {
+ // Receive basic path information
+ val (shuffleId, myIndex, outputId) =
+ ois.readObject.asInstanceOf[(Int, Int, Int)]
+
+ var requestedSplitBase = "%s/%d/%d/%d".format(
+ shuffleDir, shuffleId, myIndex, outputId)
+ logInfo("requestedSplitBase: " + requestedSplitBase)
+
+ // Read BLOCKNUM and send back the total number of blocks
+ val blockNumName = "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir,
+ shuffleId, myIndex, outputId)
+
+ val blockNumIn = new ObjectInputStream(new ByteArrayInputStream(
+ CustomBlockedInMemoryShuffle.splitsCache(blockNumName)))
+ val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
+ blockNumIn.close()
+
+ oos.writeObject(BLOCKNUM)
+
+ val startTime = System.currentTimeMillis
+ var curTime = startTime
+ var keepSending = true
+ var numBlocksToSend = Shuffle.MaxChatBlocks
+
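+ // Serve at most MaxChatBlocks blocks and MaxChatTime millis per
+ // connection so a busy sender rotates among waiting receivers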
+ while (keepSending && numBlocksToSend > 0) {
+ // Receive specific block request
+ val blockId = ois.readObject.asInstanceOf[Int]
+
+ // Ready to send
+ var requestedSplit = requestedSplitBase + "-" + blockId
+
+ // Send the length of the requestedSplit to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ var requestedSplitLen = -1
+
+ try {
+ requestedSplitLen =
+ CustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length
+ } catch {
+ case e: Exception => { }
+ }
+
+ oos.writeObject(requestedSplitLen)
+ oos.flush()
+
+ logInfo("requestedSplitLen = " + requestedSplitLen)
+
+ // Read and send the requested file
+ if (requestedSplitLen != -1) {
+ // Send
+ bos.write(CustomBlockedInMemoryShuffle.splitsCache(requestedSplit),
+ 0, requestedSplitLen)
+ bos.flush()
+
+ // Update loop variables
+ numBlocksToSend = numBlocksToSend - 1
+
+ curTime = System.currentTimeMillis
+ // Revoke sending only if there is anyone waiting in the queue
+ if (curTime - startTime >= Shuffle.MaxChatTime &&
+ threadPool.getQueue.size > 0) {
+ keepSending = false
+ }
+ } else {
+ // Close the connection
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ // EOFException is expected to happen because receiver can break
+ // connection as soon as it has all the blocks
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/CustomBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/CustomBlockedLocalFileShuffle.scala
new file mode 100644
index 0000000000..98af7c8d65
--- /dev/null
+++ b/core/src/main/scala/spark/CustomBlockedLocalFileShuffle.scala
@@ -0,0 +1,654 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using local files served through custom server
+ * where receivers create simultaneous connections to multiple servers by
+ * setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * By controlling the 'spark.shuffle.blockSize' config option one can also
+ * control the largest block size to divide each map output into. Essentially,
+ * instead of creating one large output file for each reducer, maps create
+ * multiple smaller files to enable a finer level of engagement.
+ *
+ * 'spark.shuffle.maxTxConnections' enforces a server-side cap. Ideally,
+ * maxTxConnections >= maxRxConnections * numReducersPerMachine
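+ * (for instance, 4 Rx connections and two reducers per machine call
+ * for maxTxConnections >= 8)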
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class CustomBlockedLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+
+ @transient var totalBlocksInSplit: Array[Int] = null
+ @transient var hasBlocksInSplit: Array[Int] = null
+
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = CustomBlockedLocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ var blockNum = 0
+ var isDirty = false
+ var file: File = null
+ var out: ObjectOutputStream = null
+
+ var writeStartTime: Long = 0
+
+ buckets(i).foreach(pair => {
+ // Open a new file if necessary
+ if (!isDirty) {
+ file = CustomBlockedLocalFileShuffle.getOutputFile(shuffleId,
+ myIndex, i, blockNum)
+ writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + file)
+
+ out = new ObjectOutputStream(new FileOutputStream(file))
+ }
+
+ out.writeObject(pair)
+ out.flush()
+ isDirty = true
+
+ // Close the old file if it has crossed the blockSize limit
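+ // (file.length is up to date because out is flushed after every
+ // write, so a block overshoots blockSize by at most one pair)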
+ if (file.length > Shuffle.BlockSize) {
+ out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ isDirty = false
+ }
+ })
+
+ if (isDirty) {
+ out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ }
+
+ // Write the BLOCKNUM file
+ file = CustomBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId,
+ myIndex, i)
+ out = new ObjectOutputStream(new FileOutputStream(file))
+ out.writeObject(blockNum)
+ out.close()
+ }
+
+ (myIndex, CustomBlockedLocalFileShuffle.serverAddress,
+ CustomBlockedLocalFileShuffle.serverPort)
+ }).collect()
+
+ val splitsByUri = new ArrayBuffer[(String, Int, Int)]
+ for ((inputId, serverAddress, serverPort) <- outputLocs) {
+ splitsByUri += ((serverAddress, serverPort, inputId))
+ }
+
+ // TODO: Could broadcast outputLocs
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+
+ totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
+ hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
+
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ var numThreadsToCreate =
+ Math.min(totalSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Select a random split to pull
+ val splitIndex = selectRandomSplit
+
+ if (splitIndex != -1) {
+ val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
+ val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
+
+ threadPool.execute(new ShuffleClient(serverAddress, serverPort,
+ shuffleId.toInt, inputId, myId, splitIndex))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+ // Don't return until consumption is finished
+ // while (receivedData.size > 0) {
+ // Thread.sleep(Shuffle.MinKnockInterval)
+ // }
+
+ // Wait till shuffleConsumer is done
+ shuffleConsumer.join
+
+ combiners
+ })
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+ synchronized {
+ for (i <- 0 until totalSplits) {
+ if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+ requiredSplits += i
+ }
+ }
+ }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(CustomBlockedLocalFileShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+ // Run until all splits are here
+ while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(hostAddress: String, listenPort: Int, shuffleId: Int,
+ inputId: Int, myId: Int, splitIndex: Int)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ override def run: Unit = {
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUpConnections()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
+
+ try {
+ // Everything will break if BLOCKNUM is not correctly received
+ // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
+ peerSocketToSource = new Socket(hostAddress, listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send path information
+ oosSource.writeObject((shuffleId, inputId, myId))
+
+ // TODO: Can be optimized. No need to do it every time.
+ // Receive BLOCKNUM
+ totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel()
+
+ while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) {
+ // Set receptionSucceeded to false before trying for each block
+ receptionSucceeded = false
+
+ // Request specific block
+ oosSource.writeObject(hasBlocksInSplit(splitIndex))
+
+ // Good to go. First, receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+ val requestPath = "%d/%d/%d-%d".format(shuffleId, inputId, myId,
+ hasBlocksInSplit(splitIndex))
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ while (alreadyRead != requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
+
+ // Split has been received only if all the blocks have been received
+ if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+ }
+
+ // Consistent state in accounting variables
+ receptionSucceeded = true
+
+ logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath) + " took " + readTime + " millis.")
+ } else {
+ throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath)
+ }
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ cleanUpConnections()
+ }
+ }
+
+ private def cleanUpConnections(): Unit = {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+ }
+ }
+}
+
+object CustomBlockedLocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int,
+ blockId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "%d-%d".format(outputId, blockId))
+ return file
+ }
+
+ def getBlockNumOutputFile(shuffleId: Long, inputId: Int,
+ outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "BLOCKNUM-" + outputId)
+ return file
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ // Returns a standard ThreadFactory except all threads are daemons
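+ // (daemon threads don't keep the JVM alive once the worker finishes)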
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread(r)
+ t.setDaemon(true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool =
+ newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
+ override def run: Unit = {
+ try {
+ // Receive basic path information
+ val (shuffleId, myIndex, outputId) =
+ ois.readObject.asInstanceOf[(Int, Int, Int)]
+
+ var requestPathBase = "%d/%d/%d".format(shuffleId, myIndex, outputId)
+
+ logInfo("requestPathBase: " + requestPathBase)
+
+ // Read BLOCKNUM file and send back the total number of blocks
+ val blockNumFilePath = "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir,
+ shuffleId, myIndex, outputId)
+ val blockNumIn =
+ new ObjectInputStream(new FileInputStream(blockNumFilePath))
+ val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
+ blockNumIn.close()
+
+ oos.writeObject(BLOCKNUM)
+
+ val startTime = System.currentTimeMillis
+ var curTime = startTime
+ var keepSending = true
+ var numBlocksToSend = Shuffle.MaxChatBlocks
+
+ while (keepSending && numBlocksToSend > 0) {
+ // Receive specific block request
+ val blockId = ois.readObject.asInstanceOf[Int]
+
+ // Ready to send
+ val requestPath = requestPathBase + "-" + blockId
+
+ // Open the file
+ var requestedFile: File = null
+ var requestedFileLen = -1
+ try {
+ requestedFile = new File(shuffleDir + "/" + requestPath)
+ requestedFileLen = requestedFile.length.toInt
+ } catch {
+ case e: Exception => { }
+ }
+
+ // Send the length of the requestPath to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ oos.writeObject(requestedFileLen)
+ oos.flush()
+
+ logInfo("requestedFileLen = " + requestedFileLen)
+
+ // Read and send the requested file
+ if (requestedFileLen != -1) {
+ // Read
+ var byteArray = new Array[Byte](requestedFileLen)
+ val bis =
+ new BufferedInputStream(new FileInputStream(requestedFile))
+
+ var bytesRead = bis.read(byteArray, 0, byteArray.length)
+ var alreadyRead = bytesRead
+
+ while (alreadyRead < requestedFileLen) {
+ bytesRead = bis.read(byteArray, alreadyRead,
+ (byteArray.length - alreadyRead))
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+ bis.close()
+
+ // Send
+ bos.write(byteArray, 0, byteArray.length)
+ bos.flush()
+
+ // Update loop variables
+ numBlocksToSend = numBlocksToSend - 1
+
+ curTime = System.currentTimeMillis
+ // Revoke sending only if there is anyone waiting in the queue
+ if (curTime - startTime >= Shuffle.MaxChatTime &&
+ threadPool.getQueue.size > 0) {
+ keepSending = false
+ }
+ } else {
+ // Close the connection
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ // EOFException is expected to happen because receiver can break
+ // connection as soon as it has all the blocks
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/CustomParallelFakeShuffle.scala b/core/src/main/scala/spark/CustomParallelFakeShuffle.scala
new file mode 100644
index 0000000000..889c5111b6
--- /dev/null
+++ b/core/src/main/scala/spark/CustomParallelFakeShuffle.scala
@@ -0,0 +1,495 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATION (FOR NOW).
+ *
+ * An implementation of shuffle using fake data served through custom server
+ * where receivers create simultaneous connections to multiple servers by
+ * setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class CustomParallelFakeShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = CustomParallelFakeShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ val splitName =
+ CustomParallelFakeShuffle.getSplitName(shuffleId, myIndex, i)
+
+ val writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + splitName)
+
+ // Write buckets(i) to a byte array & put in splitsCache instead of file
+ val baos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream(baos)
+ buckets(i).foreach(pair => oos.writeObject(pair))
+ oos.close
+ baos.close
+
+ // Store the length only
+ CustomParallelFakeShuffle.splitsCache(splitName) = baos.toByteArray.length
+ val splitLen = baos.toByteArray.length
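+ // Only the size survives; at reduce time the server streams that many
+ // filler bytes, making this variant a network-transfer-only benchmark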
+
+ logInfo("END WRITE: " + splitName)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.")
+ }
+
+ (myIndex, CustomParallelFakeShuffle.serverAddress,
+ CustomParallelFakeShuffle.serverPort)
+ }).collect()
+
+ val splitsByUri = new ArrayBuffer[(String, Int, Int)]
+ for ((inputId, serverAddress, serverPort) <- outputLocs) {
+ splitsByUri += ((serverAddress, serverPort, inputId))
+ }
+
+ // TODO: Could broadcast splitsByUri
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = splitsByUri.size
+ hasSplits = 0
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ var totalThreadsCreated = 0
+
+ while (hasSplits < totalSplits) {
+ var numThreadsToCreate = Math.min(totalSplits,
+ Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Select a random split to pull
+ val splitIndex = selectRandomSplit
+
+ if (splitIndex != -1) {
+ val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
+ val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId)
+
+ threadPool.execute(new ShuffleClient(splitIndex, serverAddress,
+ serverPort, requestSplit))
+ totalThreadsCreated += 1
+ logInfo("totalThreadsCreated = %d / threadPool.getActiveCount = %d".format(totalThreadsCreated, threadPool.getActiveCount.toInt))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ combiners
+ })
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+ synchronized {
+ for (i <- 0 until totalSplits) {
+ if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+ requiredSplits += i
+ }
+ }
+ }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(CustomParallelFakeShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int,
+ requestSplit: String)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ override def run: Unit = {
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUpConnections()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
+
+ logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit))
+
+ try {
+ // Connect to the source
+ peerSocketToSource = new Socket(hostAddress, listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send the request
+ oosSource.writeObject(requestSplit)
+
+ // Receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel()
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+
+// var recvByteArray = new Array[Byte](requestedFileLen)
+// var alreadyRead = 0
+// var bytesRead = 0
+//
+// while (alreadyRead != requestedFileLen) {
+// bytesRead = isSource.read(recvByteArray, alreadyRead,
+// requestedFileLen - alreadyRead)
+// if (bytesRead > 0) {
+// alreadyRead = alreadyRead + bytesRead
+// }
+// }
+
+ // Receive data in an Array[Byte]
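+ // (the 64 KB buffer is overwritten on each read and its contents
+ // discarded; only the running byte count matters here)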
+ var recvByteArray = new Array[Byte](65536)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ while (alreadyRead != requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+
+ // We have received splitIndex
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
+ } else {
+ throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ // If reception failed, unset for future retry
+ if (!receptionSucceeded) {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ }
+ cleanUpConnections()
+ }
+ }
+
+ private def cleanUpConnections(): Unit = {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+ }
+ }
+}
+
+object CustomParallelFakeShuffle extends Logging {
+ // Cache for keeping the splits around
+ val splitsCache = new HashMap[String, Int]
+
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = {
+ initializeIfNeeded()
+ // Adding shuffleDir is unnecessary. Added to keep the parsers working
+ return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId)
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ // Returns a standard ThreadFactory except all threads are daemons
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread(r)
+ t.setDaemon(true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
+ override def run: Unit = {
+ try {
+ // Receive requestedSplit from the receiver
+ // Adding shuffleDir is unnecessary. Added to keep the parsers working
+ var requestedSplit =
+ shuffleDir + "/" + ois.readObject.asInstanceOf[String]
+ logInfo("requestedSplit: " + requestedSplit)
+
+ // Send the length of the requestedSplit to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ var requestedSplitLen = -1
+
+ try {
+ requestedSplitLen =
+ CustomParallelFakeShuffle.splitsCache(requestedSplit)
+ } catch {
+ case e: Exception => { }
+ }
+
+ oos.writeObject(requestedSplitLen)
+ oos.flush()
+
+ logInfo("requestedSplitLen = " + requestedSplitLen)
+
+ // Read and send the requested split
+          if (requestedSplitLen != -1) {
+            // Send fake (zeroed) data in fixed-size chunks
+
+ val buf = new Array[Byte](65536)
+ var bytesSent = 0
+ while (bytesSent < requestedSplitLen) {
+ val bytesToSend = Math.min(requestedSplitLen - bytesSent, buf.length)
+ bos.write(buf, 0, bytesToSend)
+ bos.flush()
+ bytesSent += bytesToSend
+ }
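+          // (For example, the loop above sends requestedSplitLen = 150000
+          // as three writes: 65536 + 65536 + 18928 bytes.)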
+ } else {
+ // Close the connection
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala b/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala
new file mode 100644
index 0000000000..48f3685a1a
--- /dev/null
+++ b/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala
@@ -0,0 +1,535 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATION (FOR NOW).
+ *
+ * An implementation of shuffle using local memory served through a custom
+ * server where receivers create simultaneous connections to multiple servers
+ * by setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class CustomParallelInMemoryShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = CustomParallelInMemoryShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
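+        // e.g. k.hashCode == -3 with 4 output splits: -3 % 4 == -3 in
+        // Java/Scala, so add 4 to land in bucket 1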
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ val splitName =
+ CustomParallelInMemoryShuffle.getSplitName(shuffleId, myIndex, i)
+
+ val writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + splitName)
+
+ // Write buckets(i) to a byte array & put in splitsCache instead of file
+ val baos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream(baos)
+ buckets(i).foreach(pair => oos.writeObject(pair))
+        oos.close()
+        baos.close()
+
+ CustomParallelInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+ val splitLen =
+ CustomParallelInMemoryShuffle.splitsCache(splitName).length
+
+ logInfo("END WRITE: " + splitName)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.")
+ }
+
+ (myIndex, CustomParallelInMemoryShuffle.serverAddress,
+ CustomParallelInMemoryShuffle.serverPort)
+ }).collect()
+
+ val splitsByUri = new ArrayBuffer[(String, Int, Int)]
+ for ((inputId, serverAddress, serverPort) <- outputLocs) {
+ splitsByUri += ((serverAddress, serverPort, inputId))
+ }
+
+ // TODO: Could broadcast splitsByUri
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = splitsByUri.size
+ hasSplits = 0
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ var numThreadsToCreate = Math.min(totalSplits,
+ Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Select a random split to pull
+ val splitIndex = selectRandomSplit
+
+ if (splitIndex != -1) {
+ val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
+ val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId)
+
+ threadPool.execute(new ShuffleClient(splitIndex, serverAddress,
+ serverPort, requestSplit))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+      // Don't return until consumption is finished: wait for the
+      // consumer thread instead of polling receivedData
+      shuffleConsumer.join()
+
+ combiners
+ })
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+    // Lock the bitvectors themselves, as their writers do, instead of
+    // the enclosing instance
+    hasSplitsBitVector.synchronized {
+      splitsInRequestBitVector.synchronized {
+        for (i <- 0 until totalSplits) {
+          if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+            requiredSplits += i
+          }
+        }
+      }
+    }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(CustomParallelInMemoryShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
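+
+  // A small sketch of the candidate selection above, assuming 4 splits
+  // where split 0 has already arrived and split 2 is in flight:
+  //
+  //   val has = new BitSet(4);   has.set(0)
+  //   val inReq = new BitSet(4); inReq.set(2)
+  //   val candidates =
+  //     (0 until 4).filter(i => !has.get(i) && !inReq.get(i))  // 1 and 3
+  //   // one of them is then picked with ranGen.nextInt(candidates.size)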
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+      // Drain the queue; the consumer is started only after all splits
+      // have been received, so this loops until everything is consumed
+      while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+      try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int,
+ requestSplit: String)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ override def run: Unit = {
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUpConnections()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
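+
+      // If the source does not respond within Shuffle.MaxKnockInterval
+      // millis, the task above closes the connection, which fails the
+      // blocked reads below; the timer is cancelled once the split length
+      // arrives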
+
+ logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit))
+
+ try {
+ // Connect to the source
+ peerSocketToSource = new Socket(hostAddress, listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send the request
+ oosSource.writeObject(requestSplit)
+
+ // Receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel()
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ while (alreadyRead != requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
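+
+          // The loop above is a manual read-fully; a hypothetical
+          // equivalent using the standard library would be:
+          //
+          //   val din = new DataInputStream(isSource)
+          //   din.readFully(recvByteArray)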
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+
+ // We have received splitIndex
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
+ } else {
+ throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ // If reception failed, unset for future retry
+ if (!receptionSucceeded) {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ }
+ cleanUpConnections()
+ }
+ }
+
+ private def cleanUpConnections(): Unit = {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+ }
+ }
+}
+
+object CustomParallelInMemoryShuffle extends Logging {
+ // Cache for keeping the splits around
+ val splitsCache = new HashMap[String, Array[Byte]]
+
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = {
+ initializeIfNeeded()
+    // Prepending shuffleDir is unnecessary; it is kept so that existing
+    // parsers keep working
+ return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId)
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ // Returns a standard ThreadFactory except all threads are daemons
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread(r)
+ t.setDaemon(true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+      private val os = clientSocket.getOutputStream
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
+ override def run: Unit = {
+ try {
+ // Receive requestedSplit from the receiver
+          // Prepending shuffleDir is unnecessary; it is kept so that
+          // existing parsers keep working
+ var requestedSplit =
+ shuffleDir + "/" + ois.readObject.asInstanceOf[String]
+ logInfo("requestedSplit: " + requestedSplit)
+
+ // Send the length of the requestedSplit to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ var requestedSplitLen = -1
+
+ try {
+ requestedSplitLen =
+ CustomParallelInMemoryShuffle.splitsCache(requestedSplit).length
+ } catch {
+ case e: Exception => { }
+ }
+
+ oos.writeObject(requestedSplitLen)
+ oos.flush()
+
+ logInfo("requestedSplitLen = " + requestedSplitLen)
+
+ // Read and send the requested split
+ if (requestedSplitLen != -1) {
+ // Send
+ bos.write(CustomParallelInMemoryShuffle.splitsCache(requestedSplit),
+ 0, requestedSplitLen)
+ bos.flush()
+ } else {
+ // Close the connection
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}
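+
+// A hedged usage sketch (names and types are illustrative; `pairs` is any
+// RDD of key-value pairs and the three closures are the usual combiners):
+//
+//   val pairs: RDD[(String, Int)] = ...
+//   val counts = new CustomParallelInMemoryShuffle[String, Int, Int]
+//     .compute(pairs, 4,
+//       (v: Int) => v,                 // createCombiner
+//       (c: Int, v: Int) => c + v,     // mergeValue
+//       (c1: Int, c2: Int) => c1 + c2) // mergeCombiners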
diff --git a/core/src/main/scala/spark/CustomParallelLocalFileShuffle.scala b/core/src/main/scala/spark/CustomParallelLocalFileShuffle.scala
new file mode 100644
index 0000000000..87e824fb2e
--- /dev/null
+++ b/core/src/main/scala/spark/CustomParallelLocalFileShuffle.scala
@@ -0,0 +1,541 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using local files served through a custom
+ * server where receivers create simultaneous connections to multiple servers
+ * by setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * 'spark.shuffle.maxTxConnections' enforces a server-side cap. Ideally,
+ * maxTxConnections >= maxRxConnections * numReducersPerMachine.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class CustomParallelLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = CustomParallelLocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ val file = CustomParallelLocalFileShuffle.getOutputFile(shuffleId,
+ myIndex, i)
+ val writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + file)
+ val out = new ObjectOutputStream(new FileOutputStream(file))
+ buckets(i).foreach(pair => out.writeObject(pair))
+ out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+ }
+
+ (myIndex, CustomParallelLocalFileShuffle.serverAddress,
+ CustomParallelLocalFileShuffle.serverPort)
+ }).collect()
+
+ val splitsByUri = new ArrayBuffer[(String, Int, Int)]
+ for ((inputId, serverAddress, serverPort) <- outputLocs) {
+ splitsByUri += ((serverAddress, serverPort, inputId))
+ }
+
+ // TODO: Could broadcast splitsByUri
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = splitsByUri.size
+ hasSplits = 0
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ var numThreadsToCreate = Math.min(totalSplits,
+ Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Select a random split to pull
+ val splitIndex = selectRandomSplit
+
+ if (splitIndex != -1) {
+ val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
+ val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
+
+ threadPool.execute(new ShuffleClient(splitIndex, serverAddress,
+ serverPort, requestPath))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+      // Don't return until consumption is finished: wait for the
+      // consumer thread instead of polling receivedData
+      shuffleConsumer.join()
+
+ combiners
+ })
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+    // Lock the bitvectors themselves, as their writers do, instead of
+    // the enclosing instance
+    hasSplitsBitVector.synchronized {
+      splitsInRequestBitVector.synchronized {
+        for (i <- 0 until totalSplits) {
+          if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+            requiredSplits += i
+          }
+        }
+      }
+    }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(CustomParallelLocalFileShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+      // Drain the queue; the consumer is started only after all splits
+      // have been received, so this loops until everything is consumed
+      while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+      try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int,
+ requestPath: String)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ override def run: Unit = {
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUpConnections()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
+
+ logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath))
+
+ try {
+ // Connect to the source
+ peerSocketToSource = new Socket(hostAddress, listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send the request
+ oosSource.writeObject(requestPath)
+
+ // Receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel()
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ while (alreadyRead != requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+
+ // We have received splitIndex
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath) + " took " + readTime + " millis.")
+ } else {
+ throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath)
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ // If reception failed, unset for future retry
+ if (!receptionSucceeded) {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ }
+ cleanUpConnections()
+ }
+ }
+
+ private def cleanUpConnections(): Unit = {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+ }
+ }
+}
+
+object CustomParallelLocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "" + outputId)
+ return file
+ }
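+
+  // For example, getOutputFile(5, 2, 7) above resolves to the file
+  // <shuffleDir>/5/2/7, with intermediate directories created on demand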
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ // Returns a standard ThreadFactory except all threads are daemons
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread(r)
+ t.setDaemon(true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool =
+ newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+        private val os = clientSocket.getOutputStream
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
+ override def run: Unit = {
+ try {
+ // Receive requestPath from the receiver
+ var requestPath = ois.readObject.asInstanceOf[String]
+ logInfo("requestPath: " + shuffleDir + "/" + requestPath)
+
+ // Open the file
+ var requestedFile: File = null
+ var requestedFileLen = -1
+          try {
+            requestedFile = new File(shuffleDir + "/" + requestPath)
+            // File.length returns 0 rather than throwing for a missing
+            // file, so check existence explicitly to preserve the -1
+            // "not found" sentinel sent to the receiver
+            if (requestedFile.exists) {
+              requestedFileLen = requestedFile.length.toInt
+            }
+          } catch {
+            case e: Exception => { }
+          }
+
+ // Send the length of the requestPath to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ oos.writeObject(requestedFileLen)
+ oos.flush()
+
+ logInfo("requestedFileLen = " + requestedFileLen)
+
+ // Read and send the requested file
+ if (requestedFileLen != -1) {
+ // Read
+ var byteArray = new Array[Byte](requestedFileLen)
+ val bis =
+ new BufferedInputStream(new FileInputStream(requestedFile))
+
+ var bytesRead = bis.read(byteArray, 0, byteArray.length)
+ var alreadyRead = bytesRead
+
+ while (alreadyRead < requestedFileLen) {
+ bytesRead = bis.read(byteArray, alreadyRead,
+ (byteArray.length - alreadyRead))
+ if(bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+ bis.close()
+
+ // Send
+ bos.write(byteArray, 0, byteArray.length)
+ bos.flush()
+ } else {
+ // Close the connection
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/DfsShuffle.scala b/core/src/main/scala/spark/DfsShuffle.scala
index 7a42bf2d06..bf91be7d2c 100644
--- a/core/src/main/scala/spark/DfsShuffle.scala
+++ b/core/src/main/scala/spark/DfsShuffle.scala
@@ -9,7 +9,6 @@ import scala.collection.mutable.HashMap
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
-
/**
* A simple implementation of shuffle using a distributed file system.
*
@@ -82,7 +81,6 @@ class DfsShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
}
}
-
/**
* Companion object of DfsShuffle; responsible for initializing a Hadoop
* FileSystem object based on the spark.dfs property and generating names
diff --git a/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala
new file mode 100644
index 0000000000..8e89cadfdd
--- /dev/null
+++ b/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala
@@ -0,0 +1,471 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using local files served through HTTP where
+ * receivers create simultaneous connections to multiple servers by setting the
+ * 'spark.shuffle.maxRxConnections' config option.
+ *
+ * By controlling the 'spark.shuffle.blockSize' config option one can also
+ * control the largest block size retrieved by each reducer. An INDEX file
+ * keeps track of block boundaries instead of creating many smaller files.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class HttpBlockedLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+
+ @transient var blocksInSplit: Array[ArrayBuffer[Long]] = null
+ @transient var totalBlocksInSplit: Array[Int] = null
+ @transient var hasBlocksInSplit: Array[Int] = null
+
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = HttpBlockedLocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ // Open the INDEX file
+ var indexFile: File =
+ HttpBlockedLocalFileShuffle.getBlockIndexOutputFile(shuffleId,
+ myIndex, i)
+ var indexOut = new ObjectOutputStream(new FileOutputStream(indexFile))
+ var indexDirty: Boolean = true
+ var alreadyWritten: Long = 0
+
+ // Open the actual file
+ var file: File =
+ HttpBlockedLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+ val out = new ObjectOutputStream(new FileOutputStream(file))
+
+ val writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + file)
+
+ buckets(i).foreach(pair => {
+ out.writeObject(pair)
+ out.flush()
+ indexDirty = true
+
+ // Update the INDEX file if more than blockSize limit has been written
+ if (file.length - alreadyWritten > Shuffle.BlockSize) {
+ indexOut.writeObject(file.length)
+ indexDirty = false
+ alreadyWritten = file.length
+ }
+ })
+
+ // Write down the last range if it was not written
+ if (indexDirty) {
+ indexOut.writeObject(file.length)
+ }
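+        // Worked example, assuming Shuffle.BlockSize = 1024 bytes: if the
+        // data file grows to 400, then 1500, then 1900 bytes across
+        // writes, the INDEX records 1500 (first boundary past the limit)
+        // followed by 1900 (the final boundary)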
+
+ out.close()
+ indexOut.close()
+
+ logInfo("END WRITE: " + file)
+ val writeTime = (System.currentTimeMillis - writeStartTime)
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+ }
+
+ (myIndex, HttpBlockedLocalFileShuffle.serverUri)
+ }).collect()
+
+ // TODO: Could broadcast outputLocs
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+
+ blocksInSplit = Array.tabulate(totalSplits)(_ => new ArrayBuffer[Long])
+ totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
+ hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
+
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ var numThreadsToCreate =
+ Math.min(totalSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Select a random split to pull
+ val splitIndex = selectRandomSplit
+
+ if (splitIndex != -1) {
+ val (inputId, serverUri) = outputLocs(splitIndex)
+
+ threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt,
+ inputId, myId, splitIndex))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+      // Don't return until consumption is finished: wait for the
+      // consumer thread instead of polling receivedData
+      shuffleConsumer.join()
+
+ combiners
+ })
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+    // Lock the bitvectors themselves, as their writers do, instead of
+    // the enclosing instance
+    hasSplitsBitVector.synchronized {
+      splitsInRequestBitVector.synchronized {
+        for (i <- 0 until totalSplits) {
+          if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+            requiredSplits += i
+          }
+        }
+      }
+    }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(HttpBlockedLocalFileShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+      // Drain the queue; the consumer is started only after all splits
+      // have been received, so this loops until everything is consumed
+      while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+      try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(serverUri: String, shuffleId: Int,
+ inputId: Int, myId: Int, splitIndex: Int)
+ extends Thread with Logging {
+ private var receptionSucceeded = false
+
+ override def run: Unit = {
+ try {
+ // First get the INDEX file if totalBlocksInSplit(splitIndex) is unknown
+ if (totalBlocksInSplit(splitIndex) == -1) {
+ val url = "%s/shuffle/%d/%d/INDEX-%d".format(serverUri, shuffleId,
+ inputId, myId)
+ val inputStream = new ObjectInputStream(new URL(url).openStream())
+
+ try {
+ while (true) {
+ blocksInSplit(splitIndex) +=
+ inputStream.readObject().asInstanceOf[Long]
+ }
+ } catch {
+ case e: EOFException => {}
+ }
+
+ totalBlocksInSplit(splitIndex) = blocksInSplit(splitIndex).size
+ inputStream.close()
+ }
+
+ // Open connection
+ val urlString =
+ "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, myId)
+ val url = new URL(urlString)
+ val httpConnection =
+ url.openConnection().asInstanceOf[HttpURLConnection]
+
+ // Set the range to download
+ val blockStartsAt = hasBlocksInSplit(splitIndex) match {
+ case 0 => 0
+ case _ => blocksInSplit(splitIndex)(hasBlocksInSplit(splitIndex) - 1) + 1
+ }
+ val blockEndsAt = blocksInSplit(splitIndex)(hasBlocksInSplit(splitIndex))
+ httpConnection.setRequestProperty("Range",
+ "bytes=" + blockStartsAt + "-" + blockEndsAt)
+
+ // Connect to the server
+ httpConnection.connect()
+
+        val urlStringWithRange =
+          urlString + "[%d:%d]".format(blockStartsAt, blockEndsAt)
+        val readStartTime = System.currentTimeMillis
+        logInfo("BEGIN READ: " + urlStringWithRange)
+
+ // Receive data in an Array[Byte]
+ val requestedFileLen: Int = (blockEndsAt - blockStartsAt).toInt + 1
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ val isSource = httpConnection.getInputStream()
+ while (alreadyRead != requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // Disconnect
+ httpConnection.disconnect()
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
+
+ // Split has been received only if all the blocks have been received
+ if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+ }
+
+ // We have received splitIndex
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: " + urStringWithRange)
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading " + urStringWithRange + " took " + readTime + " millis.")
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ // If reception failed, unset for future retry
+ if (!receptionSucceeded) {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ }
+ }
+ }
+ }
+}
+
+object HttpBlockedLocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+ private var server: HttpServer = null
+ private var serverUri: String = null
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ val extServerPort = System.getProperty(
+ "spark.localFileShuffle.external.server.port", "-1").toInt
+ if (extServerPort != -1) {
+ // We're using an external HTTP server; set URI relative to its root
+ var extServerPath = System.getProperty(
+ "spark.localFileShuffle.external.server.path", "")
+ if (extServerPath != "" && !extServerPath.endsWith("/")) {
+ extServerPath += "/"
+ }
+ serverUri = "http://%s:%d/%s/spark-local-%s".format(
+ Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+ } else {
+ // Create our own server
+ server = new HttpServer(localDir)
+ server.start()
+ serverUri = server.uri
+ }
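+      // For example, a deployment might set (hypothetical values)
+      //   -Dspark.localFileShuffle.external.server.port=8080
+      //   -Dspark.localFileShuffle.external.server.path=some/path
+      // to serve the spark-local-* directories from an existing HTTP
+      // server instead of the embedded one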
+ initialized = true
+ logInfo("Local URI: " + serverUri)
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "" + outputId)
+ return file
+ }
+
+ def getBlockIndexOutputFile(shuffleId: Long, inputId: Int,
+ outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "INDEX-" + outputId)
+ return file
+ }
+
+ def getServerUri(): String = {
+ initializeIfNeeded()
+ serverUri
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ // Returns a standard ThreadFactory except all threads are daemons
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread(r)
+ t.setDaemon(true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+}
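+
+// A hedged tuning sketch: the receiver fan-out used above is capped by
+// Shuffle.MaxRxConnections, which per the class comment comes from the
+// 'spark.shuffle.maxRxConnections' option; e.g. (value illustrative):
+//
+//   System.setProperty("spark.shuffle.maxRxConnections", "4")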
diff --git a/core/src/main/scala/spark/HttpParallelLocalFileShuffle.scala b/core/src/main/scala/spark/HttpParallelLocalFileShuffle.scala
new file mode 100644
index 0000000000..8e7b897668
--- /dev/null
+++ b/core/src/main/scala/spark/HttpParallelLocalFileShuffle.scala
@@ -0,0 +1,391 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using local files served through HTTP where
+ * receivers create simultaneous connections to multiple servers by setting the
+ * 'spark.shuffle.maxRxConnections' config option.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class HttpParallelLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = HttpParallelLocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ val file =
+ HttpParallelLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+
+ val writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + file)
+
+ val out = new ObjectOutputStream(new FileOutputStream(file))
+ buckets(i).foreach(pair => out.writeObject(pair))
+
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+
+ out.close()
+ }
+
+ (myIndex, HttpParallelLocalFileShuffle.serverUri)
+ }).collect()
+
+ // TODO: Could broadcast outputLocs
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ var numThreadsToCreate =
+ Math.min(totalSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Select a random split to pull
+ val splitIndex = selectRandomSplit
+
+ if (splitIndex != -1) {
+ val (inputId, serverUri) = outputLocs(splitIndex)
+
+ threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt,
+ inputId, myId, splitIndex))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+      // Don't return until consumption is finished: wait for the
+      // consumer thread instead of polling receivedData
+      shuffleConsumer.join()
+
+ combiners
+ })
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+    // Lock the bitvectors themselves, as their writers do, instead of
+    // the enclosing instance
+    hasSplitsBitVector.synchronized {
+      splitsInRequestBitVector.synchronized {
+        for (i <- 0 until totalSplits) {
+          if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+            requiredSplits += i
+          }
+        }
+      }
+    }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(HttpParallelLocalFileShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+      // Drain the queue; the consumer is started only after all splits
+      // have been received, so this loops until everything is consumed
+      while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+      try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(serverUri: String, shuffleId: Int,
+ inputId: Int, myId: Int, splitIndex: Int)
+ extends Thread with Logging {
+ private var receptionSucceeded = false
+
+ override def run: Unit = {
+ try {
+ // Open connection
+ val urlString =
+ "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, myId)
+ val url = new URL(urlString)
+ val httpConnection =
+ url.openConnection().asInstanceOf[HttpURLConnection]
+
+ // Connect to the server
+ httpConnection.connect()
+
+ // Receive file length
+ var requestedFileLen = httpConnection.getContentLength
+
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: " + urlString)
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ val isSource = httpConnection.getInputStream()
+ while (alreadyRead != requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // Disconnect
+ httpConnection.disconnect()
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+
+ // We have received splitIndex
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: " + urlString)
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading " + urlString + " took " + readTime + " millis.")
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ // If reception failed, unset for future retry
+ if (!receptionSucceeded) {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ }
+ }
+ }
+ }
+}
+
+object HttpParallelLocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+ private var server: HttpServer = null
+ private var serverUri: String = null
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ val extServerPort = System.getProperty(
+ "spark.localFileShuffle.external.server.port", "-1").toInt
+ if (extServerPort != -1) {
+ // We're using an external HTTP server; set URI relative to its root
+ var extServerPath = System.getProperty(
+ "spark.localFileShuffle.external.server.path", "")
+ if (extServerPath != "" && !extServerPath.endsWith("/")) {
+ extServerPath += "/"
+ }
+ serverUri = "http://%s:%d/%s/spark-local-%s".format(
+ Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+ } else {
+ // Create our own server
+ server = new HttpServer(localDir)
+ server.start()
+ serverUri = server.uri
+ }
+ initialized = true
+ logInfo("Local URI: " + serverUri)
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "" + outputId)
+ return file
+ }
+
+ def getServerUri(): String = {
+ initializeIfNeeded()
+ serverUri
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ // Returns a standard ThreadFactory except all threads are daemons
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread(r)
+ t.setDaemon(true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+}
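
A note on the pattern above: the heart of ShuffleClient is a fixed-length HTTP
read, i.e. allocate a buffer of Content-Length bytes and loop until every byte
has arrived, because InputStream.read may return fewer bytes than requested. A
minimal standalone sketch follows; the names HttpFetchSketch and fetchAll are
illustrative only (not part of this patch), and it assumes the server reports a
valid Content-Length, as the shuffle servers here do.

    import java.io.EOFException
    import java.net.{HttpURLConnection, URL}

    object HttpFetchSketch {
      // Read exactly Content-Length bytes from the connection, failing fast
      // if the stream ends early instead of looping forever.
      def fetchAll(urlString: String): Array[Byte] = {
        val conn = new URL(urlString).openConnection().asInstanceOf[HttpURLConnection]
        conn.connect()
        val len = conn.getContentLength
        val buf = new Array[Byte](len)
        val in = conn.getInputStream
        var readSoFar = 0
        while (readSoFar < len) {
          val n = in.read(buf, readSoFar, len - readSoFar)
          if (n == -1) {
            throw new EOFException("Stream ended after " + readSoFar + " bytes")
          }
          readSoFar += n
        }
        conn.disconnect()
        buf
      }
    }
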
diff --git a/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala
new file mode 100644
index 0000000000..cfd84fdb83
--- /dev/null
+++ b/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala
@@ -0,0 +1,465 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using local files served through HTTP where
+ * receivers create simultaneous connections to multiple servers by setting the
+ * 'spark.shuffle.maxRxConnections' config option.
+ *
+ * By controlling the 'spark.shuffle.blockSize' config option one can also
+ * control the largest block size to divide each map output into. Essentially,
+ * instead of creating one large output file for each reducer, maps create
+ * multiple smaller files to enable a finer level of engagement.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class ManualBlockedLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+
+ @transient var totalBlocksInSplit: Array[Int] = null
+ @transient var hasBlocksInSplit: Array[Int] = null
+
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = ManualBlockedLocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ var blockNum = 0
+ var isDirty = false
+ var file: File = null
+ var out: ObjectOutputStream = null
+
+ var writeStartTime: Long = 0
+
+ buckets(i).foreach(pair => {
+ // Open a new file if necessary
+ if (!isDirty) {
+ file = ManualBlockedLocalFileShuffle.getOutputFile(shuffleId,
+ myIndex, i, blockNum)
+ writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + file)
+
+ out = new ObjectOutputStream(new FileOutputStream(file))
+ }
+
+ out.writeObject(pair)
+ out.flush()
+ isDirty = true
+
+          // Close the old file if it has crossed the blockSize limit
+ if (file.length > Shuffle.BlockSize) {
+ out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ isDirty = false
+ }
+ })
+
+ if (isDirty) {
+ out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ }
+
+ // Write the BLOCKNUM file
+ file = ManualBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId,
+ myIndex, i)
+ out = new ObjectOutputStream(new FileOutputStream(file))
+ out.writeObject(blockNum)
+ out.close()
+ }
+
+ (myIndex, ManualBlockedLocalFileShuffle.serverUri)
+ }).collect()
+
+ // TODO: Could broadcast outputLocs
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+
+ totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
+ hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
+
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ var numThreadsToCreate =
+ Math.min(totalSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Select a random split to pull
+ val splitIndex = selectRandomSplit
+
+ if (splitIndex != -1) {
+ val (inputId, serverUri) = outputLocs(splitIndex)
+
+ threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt,
+ inputId, myId, splitIndex, mergeCombiners))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+ // Don't return until consumption is finished
+ // while (receivedData.size > 0) {
+ // Thread.sleep(Shuffle.MinKnockInterval)
+ // }
+
+ // Wait till shuffleConsumer is done
+ shuffleConsumer.join
+
+ combiners
+ })
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+ synchronized {
+ for (i <- 0 until totalSplits) {
+ if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+ requiredSplits += i
+ }
+ }
+ }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(ManualBlockedLocalFileShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+ // Run until all splits are here
+ while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+            logInfo("Exception while taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+      try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(serverUri: String, shuffleId: Int,
+ inputId: Int, myId: Int, splitIndex: Int,
+ mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ private var receptionSucceeded = false
+
+ override def run: Unit = {
+ try {
+ // Everything will break if BLOCKNUM is not correctly received
+ // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
+ if (totalBlocksInSplit(splitIndex) == -1) {
+ val url = "%s/shuffle/%d/%d/BLOCKNUM-%d".format(serverUri, shuffleId,
+ inputId, myId)
+ val inputStream = new ObjectInputStream(new URL(url).openStream())
+ totalBlocksInSplit(splitIndex) =
+ inputStream.readObject().asInstanceOf[Int]
+ inputStream.close()
+ }
+
+ // Open connection
+ val urlString =
+ "%s/shuffle/%d/%d/%d-%d".format(serverUri, shuffleId, inputId,
+ myId, hasBlocksInSplit(splitIndex))
+ val url = new URL(urlString)
+ val httpConnection =
+ url.openConnection().asInstanceOf[HttpURLConnection]
+
+ // Connect to the server
+ httpConnection.connect()
+
+ // Receive file length
+ var requestedFileLen = httpConnection.getContentLength
+
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: " + url)
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ val isSource = httpConnection.getInputStream()
+          while (alreadyRead != requestedFileLen) {
+            bytesRead = isSource.read(recvByteArray, alreadyRead,
+              requestedFileLen - alreadyRead)
+            if (bytesRead == -1) {
+              // Fail fast on premature end of stream instead of looping forever
+              throw new EOFException("Stream ended after " + alreadyRead +
+                " of " + requestedFileLen + " bytes")
+            }
+            if (bytesRead > 0) {
+              alreadyRead = alreadyRead + bytesRead
+            }
+          }
+
+ // Disconnect
+ httpConnection.disconnect()
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+            logInfo("Exception while putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
+
+ // Split has been received only if all the blocks have been received
+ if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+ }
+
+      // This block has been received; splitIndex is no longer in transit
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: " + url)
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading " + url + " took " + readTime + " millis.")
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ // If reception failed, unset for future retry
+ if (!receptionSucceeded) {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ }
+ }
+ }
+ }
+}
+
+object ManualBlockedLocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+ private var server: HttpServer = null
+ private var serverUri: String = null
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ val extServerPort = System.getProperty(
+ "spark.localFileShuffle.external.server.port", "-1").toInt
+ if (extServerPort != -1) {
+ // We're using an external HTTP server; set URI relative to its root
+ var extServerPath = System.getProperty(
+ "spark.localFileShuffle.external.server.path", "")
+ if (extServerPath != "" && !extServerPath.endsWith("/")) {
+ extServerPath += "/"
+ }
+ serverUri = "http://%s:%d/%s/spark-local-%s".format(
+ Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+ } else {
+ // Create our own server
+ server = new HttpServer(localDir)
+ server.start()
+ serverUri = server.uri
+ }
+ initialized = true
+ logInfo("Local URI: " + serverUri)
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int,
+ blockId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "%d-%d".format(outputId, blockId))
+ return file
+ }
+
+ def getBlockNumOutputFile(shuffleId: Long, inputId: Int,
+ outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "BLOCKNUM-" + outputId)
+ return file
+ }
+
+ def getServerUri(): String = {
+ initializeIfNeeded()
+ serverUri
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ // Returns a standard ThreadFactory except all threads are daemons
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread(r)
+ t.setDaemon(true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+}
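
The write side of ManualBlockedLocalFileShuffle above reduces to one rotation
rule: keep appending serialized pairs to the current block file and open a new
one as soon as the file crosses the configured block size. A minimal sketch of
that rule under hypothetical names (writeBlocks, makeFile); the real code above
additionally records the final count in a per-reducer BLOCKNUM file.

    import java.io.{File, FileOutputStream, ObjectOutputStream}

    object BlockedWriterSketch {
      // Writes pairs into numbered block files, rotating to a new file once
      // the current one grows past blockSize; returns the number of blocks.
      def writeBlocks[K, C](pairs: Iterator[(K, C)], blockSize: Long,
          makeFile: Int => File): Int = {
        var blockNum = 0
        var file: File = null
        var out: ObjectOutputStream = null
        for (pair <- pairs) {
          if (out == null) {
            // Open a fresh block file (the "isDirty == false" state above)
            file = makeFile(blockNum)
            out = new ObjectOutputStream(new FileOutputStream(file))
          }
          out.writeObject(pair)
          out.flush()
          if (file.length > blockSize) {
            // Crossed the cap: close this block and rotate
            out.close()
            out = null
            blockNum += 1
          }
        }
        if (out != null) {
          // Close a partially filled last block
          out.close()
          blockNum += 1
        }
        blockNum
      }
    }
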
diff --git a/core/src/main/scala/spark/Shuffle.scala b/core/src/main/scala/spark/Shuffle.scala
index 4c5649b537..f2d790f727 100644
--- a/core/src/main/scala/spark/Shuffle.scala
+++ b/core/src/main/scala/spark/Shuffle.scala
@@ -1,5 +1,9 @@
package spark
+import java.net._
+import java.util.BitSet
+import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+
/**
* A trait for shuffle system. Given an input RDD and combiner functions
* for PairRDDExtras.combineByKey(), returns an output RDD.
@@ -13,3 +17,112 @@ trait Shuffle[K, V, C] {
mergeCombiners: (C, C) => C)
: RDD[(K, C)]
}
+
+/**
+ * An object containing common shuffle config parameters
+ */
+private object Shuffle
+extends Logging {
+ // Tracker communication constants
+ val ReducerEntering = 0
+ val ReducerLeaving = 1
+
+ // ShuffleTracker info
+ private var MasterHostAddress_ = System.getProperty(
+ "spark.shuffle.masterHostAddress", InetAddress.getLocalHost.getHostAddress)
+ private var MasterTrackerPort_ = System.getProperty(
+ "spark.shuffle.masterTrackerPort", "22222").toInt
+
+ private var BlockSize_ = System.getProperty(
+ "spark.shuffle.blockSize", "1024").toInt * 1024
+
+  // Used throughout the code for small and large waits/timeouts
+ private var MinKnockInterval_ = System.getProperty(
+ "spark.shuffle.minKnockInterval", "1000").toInt
+ private var MaxKnockInterval_ = System.getProperty(
+ "spark.shuffle.maxKnockInterval", "5000").toInt
+
+ // Maximum number of connections
+ private var MaxRxConnections_ = System.getProperty(
+ "spark.shuffle.maxRxConnections", "4").toInt
+ private var MaxTxConnections_ = System.getProperty(
+ "spark.shuffle.maxTxConnections", "8").toInt
+
+ // Upper limit on receiving in blocked implementations (whichever comes first)
+ private var MaxChatTime_ = System.getProperty(
+ "spark.shuffle.maxChatTime", "250").toInt
+ private var MaxChatBlocks_ = System.getProperty(
+ "spark.shuffle.maxChatBlocks", "1024").toInt
+
+ // A reducer is throttled if it is this much faster
+ private var ThrottleFraction_ = System.getProperty(
+ "spark.shuffle.throttleFraction", "2.0").toDouble
+
+ def MasterHostAddress = MasterHostAddress_
+ def MasterTrackerPort = MasterTrackerPort_
+
+ def BlockSize = BlockSize_
+
+ def MinKnockInterval = MinKnockInterval_
+ def MaxKnockInterval = MaxKnockInterval_
+
+ def MaxRxConnections = MaxRxConnections_
+ def MaxTxConnections = MaxTxConnections_
+
+ def MaxChatTime = MaxChatTime_
+ def MaxChatBlocks = MaxChatBlocks_
+
+ def ThrottleFraction = ThrottleFraction_
+
+ // Returns a standard ThreadFactory except all threads are daemons
+ private def newDaemonThreadFactory: ThreadFactory = {
+ new ThreadFactory {
+ def newThread(r: Runnable): Thread = {
+ var t = Executors.defaultThreadFactory.newThread(r)
+ t.setDaemon(true)
+ return t
+ }
+ }
+ }
+
+ // Wrapper over newFixedThreadPool
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+
+ // Wrapper over newCachedThreadPool
+ def newDaemonCachedThreadPool: ThreadPoolExecutor = {
+ var threadPool =
+ Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
+
+ threadPool.setThreadFactory(newDaemonThreadFactory)
+
+ return threadPool
+ }
+}
+
+@serializable
+case class SplitInfo(hostAddress: String, listenPort: Int,
+  splitId: Int) {
+
+ var hasSplits = 0
+ var hasSplitsBitVector: BitSet = null
+
+ // Used by mappers of dim |numOutputSplits|
+ var totalBlocksPerOutputSplit: Array[Int] = null
+ // Used by reducers of dim |numInputSplits|
+ var hasBlocksPerInputSplit: Array[Int] = null
+}
+
+object SplitInfo {
+ // Constants for special values of listenPort
+ val MappersBusy = -1
+
+ // Other constants
+ val UnusedParam = 0
+}
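
Every knob in the Shuffle object above is a plain Java system property, read
once when the object is first loaded, so overrides must be in place before the
first shuffle runs. A hedged usage sketch (the property values are arbitrary
examples, not recommendations):

    object ShuffleConfigSketch {
      def main(args: Array[String]) {
        // Set these before Shuffle is first referenced; the object caches them
        System.setProperty("spark.shuffle.blockSize", "4096")      // in KB
        System.setProperty("spark.shuffle.maxRxConnections", "8")  // per reducer
        System.setProperty("spark.shuffle.maxTxConnections", "16") // per machine
        System.setProperty("spark.shuffle.trackerStrategy",
          "spark.BalanceConnectionsShuffleTrackerStrategy")
        // ... construct the SparkContext and run the job as usual
      }
    }

Note that spark.shuffle.blockSize is interpreted in kilobytes (the object
multiplies it by 1024), and spark.shuffle.trackerStrategy is consumed by the
tracked shuffle implementations further down in this patch.
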
diff --git a/core/src/main/scala/spark/ShuffleTrackerStrategy.scala b/core/src/main/scala/spark/ShuffleTrackerStrategy.scala
new file mode 100644
index 0000000000..fc2f4aa5f7
--- /dev/null
+++ b/core/src/main/scala/spark/ShuffleTrackerStrategy.scala
@@ -0,0 +1,470 @@
+package spark
+
+import java.util.{BitSet, Random}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Sorting._
+
+/**
+ * A trait for implementing tracker strategies for the shuffle system.
+ */
+trait ShuffleTrackerStrategy {
+ // Initialize
+ def initialize(outputLocs_ : Array[SplitInfo]): Unit
+
+ // Select a set of splits and send back
+ def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int]
+
+ // Update internal stats if things could be sent back successfully
+ def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit
+
+ // A reducer is done. Update internal stats
+ def deleteReducerFrom(reducerSplitInfo: SplitInfo,
+ receptionStat: ReceptionStats): Unit
+}
+
+/**
+ * Helper class to send back reception stats from the reducer
+ */
+case class ReceptionStats(bytesReceived: Int, timeSpent: Int,
+  serverSplitIndex: Int)
+
+/**
+ * A simple ShuffleTrackerStrategy that tries to balance the total number of
+ * connections created for each mapper.
+ */
+class BalanceConnectionsShuffleTrackerStrategy
+extends ShuffleTrackerStrategy with Logging {
+ private var numSources = -1
+ private var outputLocs: Array[SplitInfo] = null
+ private var curConnectionsPerLoc: Array[Int] = null
+ private var totalConnectionsPerLoc: Array[Int] = null
+
+ // The order of elements in the outputLocs (splitIndex) is used to pass
+ // information back and forth between the tracker, mappers, and reducers
+ def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
+ outputLocs = outputLocs_
+ numSources = outputLocs.size
+
+ // Now initialize other data structures
+ curConnectionsPerLoc = Array.tabulate(numSources)(_ => 0)
+ totalConnectionsPerLoc = Array.tabulate(numSources)(_ => 0)
+ }
+
+ def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
+ var minConnections = Int.MaxValue
+ var minIndex = -1
+
+ var splitIndices = ArrayBuffer[Int]()
+
+ for (i <- 0 until numSources) {
+ // TODO: Use of MaxRxConnections instead of MaxTxConnections is
+ // intentional here. MaxTxConnections is per machine whereas
+ // MaxRxConnections is per mapper/reducer. Will have to find a better way.
+ if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections &&
+ totalConnectionsPerLoc(i) < minConnections &&
+ !reducerSplitInfo.hasSplitsBitVector.get(i)) {
+ minConnections = totalConnectionsPerLoc(i)
+ minIndex = i
+ }
+ }
+
+ if (minIndex != -1) {
+ splitIndices += minIndex
+ }
+
+ return splitIndices
+ }
+
+ def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized {
+ splitIndices.foreach { splitIndex =>
+ curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
+ totalConnectionsPerLoc(splitIndex) =
+ totalConnectionsPerLoc(splitIndex) + 1
+ }
+ }
+
+ def deleteReducerFrom(reducerSplitInfo: SplitInfo,
+ receptionStat: ReceptionStats): Unit = synchronized {
+ // Decrease number of active connections
+ curConnectionsPerLoc(receptionStat.serverSplitIndex) =
+ curConnectionsPerLoc(receptionStat.serverSplitIndex) - 1
+
+ // TODO: This assertion can legally fail when ShuffleClient times out while
+ // waiting for tracker response and decides to go to a random server
+ // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
+
+ // Just in case
+ if (curConnectionsPerLoc(receptionStat.serverSplitIndex) < 0) {
+ curConnectionsPerLoc(receptionStat.serverSplitIndex) = 0
+ }
+ }
+}
+
+/**
+ * A simple ShuffleTrackerStrategy that randomly selects a mapper for each reducer
+ */
+class SelectRandomShuffleTrackerStrategy
+extends ShuffleTrackerStrategy with Logging {
+ private var numMappers = -1
+ private var outputLocs: Array[SplitInfo] = null
+
+ private var ranGen = new Random
+
+ // The order of elements in the outputLocs (splitIndex) is used to pass
+ // information back and forth between the tracker, mappers, and reducers
+ def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
+ outputLocs = outputLocs_
+ numMappers = outputLocs.size
+ }
+
+ def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
+ var splitIndex = -1
+
+ do {
+ splitIndex = ranGen.nextInt(numMappers)
+ } while (reducerSplitInfo.hasSplitsBitVector.get(splitIndex))
+
+ return ArrayBuffer(splitIndex)
+ }
+
+ def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized {
+ }
+
+ def deleteReducerFrom(reducerSplitInfo: SplitInfo,
+ receptionStat: ReceptionStats): Unit = synchronized {
+ // TODO: This assertion can legally fail when ShuffleClient times out while
+ // waiting for tracker response and decides to go to a random server
+ // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
+ }
+}
+
+/**
+ * Shuffle tracker strategy that tries to balance the percentage of blocks
+ * remaining for each reducer
+ */
+class BalanceRemainingShuffleTrackerStrategy
+extends ShuffleTrackerStrategy with Logging {
+ // Number of mappers
+ private var numMappers = -1
+ // Number of reducers
+ private var numReducers = -1
+ private var outputLocs: Array[SplitInfo] = null
+
+ // Data structures from reducers' perspectives
+ private var totalBlocksPerInputSplit: Array[Array[Int]] = null
+ private var hasBlocksPerInputSplit: Array[Array[Int]] = null
+
+ // Stored in bytes per millisecond
+ private var speedPerInputSplit: Array[Array[Double]] = null
+
+ private var curConnectionsPerLoc: Array[Int] = null
+ private var totalConnectionsPerLoc: Array[Int] = null
+
+ // The order of elements in the outputLocs (splitIndex) is used to pass
+ // information back and forth between the tracker, mappers, and reducers
+ def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
+ outputLocs = outputLocs_
+
+ numMappers = outputLocs.size
+
+    // All the outputLocs have totalBlocksPerOutputSplit of the same size
+ numReducers = outputLocs(0).totalBlocksPerOutputSplit.size
+
+ // Now initialize the data structures
+ totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) =>
+ outputLocs(j).totalBlocksPerOutputSplit(i))
+ hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0)
+
+ // Initialize to -1
+ speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1.0)
+
+ curConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0)
+ totalConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0)
+ }
+
+ def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
+ var splitIndex = -1
+
+ // Estimate time remaining to finish receiving for all reducer/mapper pairs
+ // If speed is unknown or zero then make it 1 to give a large estimate
+ var individualEstimates = Array.tabulate(numReducers, numMappers)((_,_) => 0.0)
+ for (i <- 0 until numReducers; j <- 0 until numMappers) {
+ var blocksRemaining = totalBlocksPerInputSplit(i)(j) -
+ hasBlocksPerInputSplit(i)(j)
+ assert(blocksRemaining >= 0)
+
+ individualEstimates(i)(j) = 1.0 * blocksRemaining * Shuffle.BlockSize /
+ { if (speedPerInputSplit(i)(j) <= 0.0) 1.0 else speedPerInputSplit(i)(j) }
+ }
+
+ // Check if all speedPerInputSplit entries have non-zero values
+ var estimationComplete = true
+ for (i <- 0 until numReducers; j <- 0 until numMappers) {
+ if (speedPerInputSplit(i)(j) < 0.0) {
+ estimationComplete = false
+ }
+ }
+
+ // Mark mappers where this reducer is too fast
+ var throttleFromMapper = Array.tabulate(numMappers)(_ => false)
+
+ for (i <- 0 until numMappers) {
+ var estimatesFromAMapper =
+ Array.tabulate(numReducers)(j => individualEstimates(j)(i))
+
+ val estimateOfThisReducer = estimatesFromAMapper(reducerSplitInfo.splitId)
+
+      // Only care if this reducer still has something to receive from this mapper
+ if (estimateOfThisReducer > 0) {
+ // Sort the estimated times
+ quickSort(estimatesFromAMapper)
+
+        // Find a gap of at least a factor of Shuffle.ThrottleFraction
+ var gapIndex = -1
+        for (j <- 0 until numReducers - 1) {
+          if (gapIndex == -1 && estimatesFromAMapper(j) > 0 &&
+              (Shuffle.ThrottleFraction * estimatesFromAMapper(j) <
+                estimatesFromAMapper(j + 1))) {
+            gapIndex = j
+          }
+
+          assert (estimatesFromAMapper(j) <= estimatesFromAMapper(j + 1))
+        }
+
+ // Keep track of how many have completed
+ var numComplete = estimatesFromAMapper.findIndexOf(i => (i > 0))
+ if (numComplete == -1) {
+ numComplete = numReducers
+ }
+
+ // TODO: Pick a configurable parameter
+ if (gapIndex != -1 && (1.0 * (gapIndex - numComplete + 1) < 0.1 * Shuffle.ThrottleFraction * (numReducers - numComplete)) &&
+ estimateOfThisReducer <= estimatesFromAMapper(gapIndex)) {
+ throttleFromMapper(i) = true
+ logInfo("Throttling R-%d at M-%d with %f and cut-off %f at %d".format(reducerSplitInfo.splitId, i, estimateOfThisReducer, estimatesFromAMapper(gapIndex + 1), gapIndex))
+// for (i <- 0 until numReducers) {
+// print(estimatesFromAMapper(i) + " ")
+// }
+// println("")
+ }
+ } else {
+ throttleFromMapper(i) = true
+ }
+ }
+
+ var minConnections = Int.MaxValue
+ for (i <- 0 until numMappers) {
+ // TODO: Use of MaxRxConnections instead of MaxTxConnections is
+ // intentional here. MaxTxConnections is per machine whereas
+ // MaxRxConnections is per mapper/reducer. Will have to find a better way.
+ if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections &&
+ totalConnectionsPerLoc(i) < minConnections &&
+ !reducerSplitInfo.hasSplitsBitVector.get(i) &&
+ !throttleFromMapper(i)) {
+ minConnections = totalConnectionsPerLoc(i)
+ splitIndex = i
+ }
+ }
+
+ return ArrayBuffer(splitIndex)
+ }
+
+ def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized {
+ splitIndices.foreach { splitIndex =>
+ curConnectionsPerLoc(splitIndex) += 1
+ totalConnectionsPerLoc(splitIndex) += 1
+ }
+ }
+
+ def deleteReducerFrom(reducerSplitInfo: SplitInfo,
+ receptionStat: ReceptionStats): Unit = synchronized {
+ // Update hasBlocksPerInputSplit for reducerSplitInfo
+ hasBlocksPerInputSplit(reducerSplitInfo.splitId) =
+ reducerSplitInfo.hasBlocksPerInputSplit
+
+ // Store the last known speed. Add 1 to avoid divide-by-zero. Ignore 0 bytes
+ // TODO: We are forgetting the old speed. Can use averaging at some point.
+ if (receptionStat.bytesReceived > 0) {
+ speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) =
+ 1.0 * receptionStat.bytesReceived / (receptionStat.timeSpent + 1.0)
+ }
+
+ // logInfo("%d received %d bytes in %d millis".format(reducerSplitInfo.splitId, receptionStat.bytesReceived, receptionStat.timeSpent))
+
+ // Update current connections to the mapper
+ curConnectionsPerLoc(receptionStat.serverSplitIndex) -= 1
+
+ // TODO: This assertion can legally fail when ShuffleClient times out while
+ // waiting for tracker response and decides to go to a random server
+ // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
+
+ // Just in case
+ if (curConnectionsPerLoc(receptionStat.serverSplitIndex) < 0) {
+ curConnectionsPerLoc(receptionStat.serverSplitIndex) = 0
+ }
+ }
+}
+
+/**
+ * Shuffle tracker strategy that allows reducers to create receiving threads
+ * depending on their estimated time remaining
+ */
+class LimitConnectionsShuffleTrackerStrategy
+extends ShuffleTrackerStrategy with Logging {
+ // Number of mappers
+ private var numMappers = -1
+ // Number of reducers
+ private var numReducers = -1
+ private var outputLocs: Array[SplitInfo] = null
+
+ private var ranGen = new Random
+
+ // Data structures from reducers' perspectives
+ private var totalBlocksPerInputSplit: Array[Array[Int]] = null
+ private var hasBlocksPerInputSplit: Array[Array[Int]] = null
+
+ // Stored in bytes per millisecond
+ private var speedPerInputSplit: Array[Array[Double]] = null
+
+ private var curConnectionsPerReducer: Array[Int] = null
+ private var maxConnectionsPerReducer: Array[Int] = null
+
+ // The order of elements in the outputLocs (splitIndex) is used to pass
+ // information back and forth between the tracker, mappers, and reducers
+ def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
+ outputLocs = outputLocs_
+
+ numMappers = outputLocs.size
+
+    // All the outputLocs have totalBlocksPerOutputSplit of the same size
+ numReducers = outputLocs(0).totalBlocksPerOutputSplit.size
+
+ // Now initialize the data structures
+ totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) =>
+ outputLocs(j).totalBlocksPerOutputSplit(i))
+ hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0)
+
+ // Initialize to -1
+ speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1.0)
+
+ curConnectionsPerReducer = Array.tabulate(numReducers)(_ => 0)
+ maxConnectionsPerReducer = Array.tabulate(numReducers)(_ => Shuffle.MaxRxConnections)
+ }
+
+ def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
+ var splitIndices = ArrayBuffer[Int]()
+
+ // Estimate time remaining to finish receiving for all reducer/mapper pairs
+ // If speed is unknown or zero then make it 1 to give a large estimate
+ var individualEstimates = Array.tabulate(numReducers, numMappers)((_,_) => 0.0)
+ for (i <- 0 until numReducers; j <- 0 until numMappers) {
+ var blocksRemaining = totalBlocksPerInputSplit(i)(j) -
+ hasBlocksPerInputSplit(i)(j)
+ assert(blocksRemaining >= 0)
+
+ individualEstimates(i)(j) = 1.0 * blocksRemaining * Shuffle.BlockSize /
+ { if (speedPerInputSplit(i)(j) <= 0.0) 1.0 else speedPerInputSplit(i)(j) }
+ }
+
+ // Check if all speedPerInputSplit entries have non-zero values
+ var estimationComplete = true
+ for (i <- 0 until numReducers; j <- 0 until numMappers) {
+ if (speedPerInputSplit(i)(j) < 0.0) {
+ estimationComplete = false
+ }
+ }
+
+ // Estimate time remaining to finish receiving for each reducer
+ var completionEstimates = Array.tabulate(numReducers)(
+ individualEstimates(_).foldLeft(Double.MinValue)(Math.max(_,_)))
+
+ var numFinished = 0
+ for (i <- 0 until numReducers) {
+ if (completionEstimates(i).toInt == 0) {
+ numFinished += 1
+ }
+ }
+
+ // If a certain number of reducers have finished already, then don't bother
+ if (numFinished >= ((1.0 - 0.1 * Shuffle.ThrottleFraction) * numReducers)) {
+ for (i <- 0 until numReducers) {
+ maxConnectionsPerReducer(i) = Shuffle.MaxRxConnections
+ }
+ // Otherwise, if estimation is complete give reducers proportional threads
+ } else if (estimationComplete) {
+ val fastestEstimate =
+ completionEstimates.foldLeft(Double.MaxValue)(Math.min(_,_))
+ val slowestEstimate =
+ completionEstimates.foldLeft(Double.MinValue)(Math.max(_,_))
+
+ // Set maxConnectionsPerReducer for all reducers proportional to their
+ // estimated time remaining with slowestEstimate reducer having the max
+ for (i <- 0 until numReducers) {
+ maxConnectionsPerReducer(i) =
+ ((completionEstimates(i) / slowestEstimate) * Shuffle.MaxRxConnections).toInt
+ }
+ }
+
+ // Send back a splitIndex if this reducer is within its limit
+ if (curConnectionsPerReducer(reducerSplitInfo.splitId) <
+ maxConnectionsPerReducer(reducerSplitInfo.splitId)) {
+
+ var i = maxConnectionsPerReducer(reducerSplitInfo.splitId) -
+ curConnectionsPerReducer(reducerSplitInfo.splitId)
+
+ var temp = reducerSplitInfo.hasSplitsBitVector.clone.asInstanceOf[BitSet]
+ temp.flip(0, numMappers)
+
+ i = Math.min(i, temp.cardinality)
+
+ while (i > 0) {
+ var splitIndex = -1
+
+ do {
+ splitIndex = ranGen.nextInt(numMappers)
+ } while (reducerSplitInfo.hasSplitsBitVector.get(splitIndex))
+
+ reducerSplitInfo.hasSplitsBitVector.set(splitIndex)
+ splitIndices += splitIndex
+ i -= 1
+ }
+ }
+
+ return splitIndices
+ }
+
+ def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized {
+ splitIndices.foreach { splitIndex =>
+ curConnectionsPerReducer(reducerSplitInfo.splitId) += 1
+ }
+ }
+
+ def deleteReducerFrom(reducerSplitInfo: SplitInfo,
+ receptionStat: ReceptionStats): Unit = synchronized {
+ // Update hasBlocksPerInputSplit for reducerSplitInfo
+ hasBlocksPerInputSplit(reducerSplitInfo.splitId) =
+ reducerSplitInfo.hasBlocksPerInputSplit
+
+ // Store the last known speed. Add 1 to avoid divide-by-zero. Ignore 0 bytes
+ // TODO: We are forgetting the old speed. Can use averaging at some point.
+ if (receptionStat.bytesReceived > 0) {
+ speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) =
+ 1.0 * receptionStat.bytesReceived / (receptionStat.timeSpent + 1.0)
+ }
+
+ // logInfo("%d received %d bytes in %d millis".format(reducerSplitInfo.splitId, receptionStat.bytesReceived, receptionStat.timeSpent))
+
+    // Update the current connection count held by this reducer
+ curConnectionsPerReducer(reducerSplitInfo.splitId) -= 1
+
+ // TODO: This assertion can legally fail when ShuffleClient times out while
+ // waiting for tracker response and decides to go to a random server
+ // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
+
+ // Just in case
+ if (curConnectionsPerReducer(reducerSplitInfo.splitId) < 0) {
+ curConnectionsPerReducer(reducerSplitInfo.splitId) = 0
+ }
+ }
+}
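
Because the tracked shuffles below instantiate their strategy reflectively from
the spark.shuffle.trackerStrategy property, adding a policy only means
implementing the four ShuffleTrackerStrategy methods. A deliberately naive
round-robin sketch under that assumption; RoundRobinShuffleTrackerStrategy is
hypothetical, and it leans on the fact that the reducer's hasSplitsBitVector
already includes splits that are merely in transit (see getLocalSplitInfo
below).

    package spark

    import scala.collection.mutable.ArrayBuffer

    class RoundRobinShuffleTrackerStrategy
    extends ShuffleTrackerStrategy with Logging {
      private var numMappers = -1
      private var nextIndex = 0

      def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
        numMappers = outputLocs_.size
      }

      def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized {
        // Probe at most numMappers slots for one this reducer still needs
        var probed = 0
        while (probed < numMappers) {
          val candidate = nextIndex
          nextIndex = (nextIndex + 1) % numMappers
          if (!reducerSplitInfo.hasSplitsBitVector.get(candidate)) {
            return ArrayBuffer(candidate)
          }
          probed += 1
        }
        ArrayBuffer[Int]() // nothing left to offer this reducer
      }

      def AddReducerToSplit(reducerSplitInfo: SplitInfo,
          splitIndices: ArrayBuffer[Int]): Unit = synchronized { }

      def deleteReducerFrom(reducerSplitInfo: SplitInfo,
          receptionStat: ReceptionStats): Unit = synchronized { }
    }

It would be selected by setting
-Dspark.shuffle.trackerStrategy=spark.RoundRobinShuffleTrackerStrategy.
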
diff --git a/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala b/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala
new file mode 100644
index 0000000000..0d21df9338
--- /dev/null
+++ b/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala
@@ -0,0 +1,926 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using memory served through custom server
+ * where receivers create simultaneous connections to multiple servers by
+ * setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * By controlling the 'spark.shuffle.blockSize' config option one can also
+ * control the largest block size to divide each map output into. Essentially,
+ * instead of creating one large output file for each reducer, maps create
+ * multiple smaller files to enable a finer level of engagement.
+ *
+ * 'spark.shuffle.maxTxConnections' enforces a server-side cap. Ideally,
+ * maxTxConnections >= maxRxConnections * numReducersPerMachine
+ *
+ * 'spark.shuffle.TrackerStrategy' decides which strategy to use in the tracker
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class TrackedCustomBlockedInMemoryShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+
+ @transient var totalBlocksInSplit: Array[Int] = null
+ @transient var hasBlocksInSplit: Array[Int] = null
+
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = TrackedCustomBlockedInMemoryShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+      // Keep track of the number of blocks for each output split
+ var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0)
+
+ for (i <- 0 until numOutputSplits) {
+ var blockNum = 0
+ var isDirty = false
+
+ var splitName = ""
+ var baos: ByteArrayOutputStream = null
+ var oos: ObjectOutputStream = null
+
+ var writeStartTime: Long = 0
+
+ buckets(i).foreach(pair => {
+ // Open a new stream if necessary
+ if (!isDirty) {
+ splitName = TrackedCustomBlockedInMemoryShuffle.getSplitName(shuffleId,
+ myIndex, i, blockNum)
+
+ baos = new ByteArrayOutputStream
+ oos = new ObjectOutputStream(baos)
+
+ writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + splitName)
+ }
+
+ oos.writeObject(pair)
+ isDirty = true
+
+          // Close the old stream if it has crossed the blockSize limit
+ if (baos.size > Shuffle.BlockSize) {
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) =
+ baos.toByteArray
+
+ logInfo("END WRITE: " + splitName)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ isDirty = false
+ oos.close()
+ }
+ })
+
+ if (isDirty) {
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+
+ logInfo("END WRITE: " + splitName)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ oos.close()
+ }
+
+ // Store BLOCKNUM info
+ splitName = TrackedCustomBlockedInMemoryShuffle.getBlockNumOutputName(
+ shuffleId, myIndex, i)
+ baos = new ByteArrayOutputStream
+ oos = new ObjectOutputStream(baos)
+ oos.writeObject(blockNum)
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+
+ // Close streams
+ oos.close()
+
+ // Store number of blocks for this outputSplit
+ numBlocksPerOutputSplit(i) = blockNum
+ }
+
+ var retVal = SplitInfo(TrackedCustomBlockedInMemoryShuffle.serverAddress,
+ TrackedCustomBlockedInMemoryShuffle.serverPort, myIndex)
+ retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit
+
+ (retVal)
+ }).collect()
+
+ // Start tracker
+ var shuffleTracker = new ShuffleTracker(outputLocs)
+ shuffleTracker.setDaemon(true)
+ shuffleTracker.start()
+ logInfo("ShuffleTracker started...")
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+
+ totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
+ hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
+
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ // Local status of hasSplitsBitVector and splitsInRequestBitVector
+ val localSplitInfo = getLocalSplitInfo(myId)
+
+ // DO NOT talk to the tracker if all the required splits are already busy
+ val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality
+
+ var numThreadsToCreate =
+ Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Receive which split to pull from the tracker
+ logInfo("Talking to tracker...")
+ val startTime = System.currentTimeMillis
+ val splitIndices = getTrackerSelectedSplit(myId)
+ logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime))
+
+ if (splitIndices.size > 0) {
+ splitIndices.foreach { splitIndex =>
+ val selectedSplitInfo = outputLocs(splitIndex)
+ val requestSplit =
+ "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
+
+ threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
+ requestSplit, myId))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+ } else {
+ // Tracker replied back with a NO. Sleep for a while.
+ Thread.sleep(Shuffle.MinKnockInterval)
+ numThreadsToCreate = 0
+ }
+
+ numThreadsToCreate = numThreadsToCreate - splitIndices.size
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+ // Don't return until consumption is finished
+ // while (receivedData.size > 0) {
+ // Thread.sleep(Shuffle.MinKnockInterval)
+ // }
+
+ // Wait till shuffleConsumer is done
+ shuffleConsumer.join
+
+ combiners
+ })
+ }
+
+ private def getLocalSplitInfo(myId: Int): SplitInfo = {
+ var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress,
+ SplitInfo.UnusedParam, myId)
+
+ // Store hasSplits
+ localSplitInfo.hasSplits = hasSplits
+
+ // Store hasSplitsBitVector
+ hasSplitsBitVector.synchronized {
+ localSplitInfo.hasSplitsBitVector =
+ hasSplitsBitVector.clone.asInstanceOf[BitSet]
+ }
+
+ // Store hasBlocksInSplit to hasBlocksPerInputSplit
+ hasBlocksInSplit.synchronized {
+ localSplitInfo.hasBlocksPerInputSplit =
+ hasBlocksInSplit.clone.asInstanceOf[Array[Int]]
+ }
+
+ // Include the splitsInRequest as well
+ splitsInRequestBitVector.synchronized {
+ localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector)
+ }
+
+ return localSplitInfo
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+ synchronized {
+ for (i <- 0 until totalSplits) {
+ if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+ requiredSplits += i
+ }
+ }
+ }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(TrackedCustomBlockedInMemoryShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ // Talks to the tracker and receives instruction
+ private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = {
+ // Local status of hasSplitsBitVector and splitsInRequestBitVector
+ val localSplitInfo = getLocalSplitInfo(myId)
+
+ // DO NOT talk to the tracker if all the required splits are already busy
+ if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
+ return ArrayBuffer[Int]()
+ }
+
+ val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+ Shuffle.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ val oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ var selectedSplitIndices = ArrayBuffer[Int]()
+
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ logInfo("Waited enough for tracker response... Take random response...")
+
+ // sockets will be closed in finally
+        // TODO: Sometimes the timer won't go off
+
+ // TODO: Selecting randomly here. Tracker won't know about it and get an
+        // assertion failure when this thread leaves
+
+ selectedSplitIndices = ArrayBuffer(selectRandomSplit)
+ }
+ }
+
+ var timeOutTimer = new Timer
+ // TODO: Which timeout to use?
+ // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval)
+
+ try {
+ // Send intention
+ oosTracker.writeObject(Shuffle.ReducerEntering)
+ oosTracker.flush()
+
+ // Send what this reducer has
+ oosTracker.writeObject(localSplitInfo)
+ oosTracker.flush()
+
+ // Receive reply from the tracker
+ selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]]
+
+ // Turn the timer OFF
+ timeOutTimer.cancel()
+ } catch {
+ case e: Exception => {
+ logInfo("getTrackerSelectedSplit had a " + e)
+ }
+ } finally {
+ oisTracker.close()
+ oosTracker.close()
+ clientSocketToTracker.close()
+ }
+
+ return selectedSplitIndices
+ }
+
+ class ShuffleTracker(outputLocs: Array[SplitInfo])
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ // Create trackerStrategy object
+ val trackerStrategyClass = System.getProperty(
+ "spark.shuffle.trackerStrategy",
+ "spark.BalanceConnectionsShuffleTrackerStrategy")
+
+ val trackerStrategy =
+ Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy]
+
+ // Must initialize here by supplying the outputLocs param
+ // TODO: This could be avoided by directly passing it to the constructor
+ trackerStrategy.initialize(outputLocs)
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
+      logInfo("ShuffleTracker started with " + serverSocket)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ logInfo("ShuffleTracker had a " + e)
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run: Unit = {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ try {
+ // Receive intention
+ val reducerIntention = ois.readObject.asInstanceOf[Int]
+
+ if (reducerIntention == Shuffle.ReducerEntering) {
+ // Receive what the reducer has
+ val reducerSplitInfo =
+ ois.readObject.asInstanceOf[SplitInfo]
+
+ // Select splits and update stats if necessary
+ var selectedSplitIndices = ArrayBuffer[Int]()
+ trackerStrategy.synchronized {
+ selectedSplitIndices = trackerStrategy.selectSplit(
+ reducerSplitInfo)
+ }
+
+ // Send reply back
+ oos.writeObject(selectedSplitIndices)
+ oos.flush()
+
+ // Update internal stats, only if receiver got the reply
+ trackerStrategy.synchronized {
+ trackerStrategy.AddReducerToSplit(reducerSplitInfo,
+ selectedSplitIndices)
+ }
+ }
+ else if (reducerIntention == Shuffle.ReducerLeaving) {
+ val reducerSplitInfo =
+ ois.readObject.asInstanceOf[SplitInfo]
+
+ // Receive reception stats: how many blocks the reducer
+ // read in how much time and from where
+ val receptionStat =
+ ois.readObject.asInstanceOf[ReceptionStats]
+
+ // Update stats
+ trackerStrategy.synchronized {
+ trackerStrategy.deleteReducerFrom(reducerSplitInfo,
+ receptionStat)
+ }
+
+ // Send ACK
+ oos.writeObject(receptionStat.serverSplitIndex)
+ oos.flush()
+ }
+ else {
+ throw new SparkException("Undefined reducerIntention")
+ }
+ } catch {
+ // EOFException is expected to happen because receiver can
+ // break connection due to timeout and pick random instead
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleTracker had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+ // Run until all splits are here
+ while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+            logInfo("Exception while taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+        try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+  class ShuffleClient(splitIndex: Int, serverSplitInfo: SplitInfo,
+ requestSplit: String, myId: Int)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ // Make sure that multiple messages don't go to the tracker
+ private var alreadySentLeavingNotification = false
+
+ // Keep track of bytes received and time spent
+ private var numBytesReceived = 0
+ private var totalTimeSpent = 0
+
+ override def run: Unit = {
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUp()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
+
+ try {
+ // Everything will break if BLOCKNUM is not correctly received
+ // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
+ peerSocketToSource = new Socket(
+          serverSplitInfo.hostAddress, serverSplitInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send path information
+ oosSource.writeObject(requestSplit)
+
+        // TODO: Can be optimized; no need to do this every time.
+ // Receive BLOCKNUM
+ totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel()
+
+ while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) {
+ // Set receptionSucceeded to false before trying for each block
+ receptionSucceeded = false
+
+ // Request specific block
+ oosSource.writeObject(hasBlocksInSplit(splitIndex))
+
+ // Good to go. First, receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+          // Build the request path; used below only for logging
+          val requestPath = "http://%s:%d/shuffle/%s-%d".format(
+            serverSplitInfo.hostAddress, serverSplitInfo.listenPort, requestSplit,
+            hasBlocksInSplit(splitIndex))
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: " + requestPath)
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+              while (alreadyRead != requestedFileLen) {
+                bytesRead = isSource.read(recvByteArray, alreadyRead,
+                  requestedFileLen - alreadyRead)
+                if (bytesRead == -1) {
+                  // Fail fast on premature end of stream instead of looping
+                  throw new EOFException("Stream ended after " + alreadyRead +
+                    " of " + requestedFileLen + " bytes")
+                }
+                if (bytesRead > 0) {
+                  alreadyRead = alreadyRead + bytesRead
+                }
+              }
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+                logInfo("Exception while putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
+
+ // Split has been received only if all the blocks have been received
+ if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: " + requestPath)
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading " + requestPath + " took " + readTime + " millis.")
+
+ // Update stats
+ numBytesReceived = numBytesReceived + requestedFileLen
+ totalTimeSpent = totalTimeSpent + readTime.toInt
+ } else {
+            throw new SparkException("ShuffleServer " +
+              serverSplitInfo.hostAddress + " does not have " + requestSplit)
+ }
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ cleanUp()
+ }
+ }
+
+ // Connect to the tracker and update its stats
+ private def sendLeavingNotification(): Unit = synchronized {
+ if (!alreadySentLeavingNotification) {
+ val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+ Shuffle.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ val oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ try {
+ // Send intention
+ oosTracker.writeObject(Shuffle.ReducerLeaving)
+ oosTracker.flush()
+
+ // Send reducerSplitInfo
+ oosTracker.writeObject(getLocalSplitInfo(myId))
+ oosTracker.flush()
+
+ // Send reception stats
+ oosTracker.writeObject(ReceptionStats(
+ numBytesReceived, totalTimeSpent, splitIndex))
+ oosTracker.flush()
+
+ // Receive ACK. No need to do anything with that
+ oisTracker.readObject.asInstanceOf[Int]
+
+          // Mark that the leaving notification has been sent
+ alreadySentLeavingNotification = true
+ } catch {
+ case e: Exception => {
+ logInfo("sendLeavingNotification had a " + e)
+ }
+ } finally {
+ oisTracker.close()
+ oosTracker.close()
+ clientSocketToTracker.close()
+ }
+ }
+ }
+
+ private def cleanUp(): Unit = {
+ // Update tracker stats first
+ sendLeavingNotification()
+
+ // Clean up the connections to the mapper
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+
+ logInfo("Leaving client")
+ }
+ }
+}
+
+object TrackedCustomBlockedInMemoryShuffle extends Logging {
+ // Cache for keeping the splits around
+ val splitsCache = new HashMap[String, Array[Byte]]
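+ // Keys follow the formats produced by getSplitName and getBlockNumOutputName
+ // below: "<shuffleDir>/<shuffleId>/<inputId>/<outputId>-<blockId>" for a
+ // data block and "<shuffleDir>/<shuffleId>/<inputId>/<outputId>-BLOCKNUM"
+ // for the serialized block count of an output split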
+
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getSplitName(shuffleId: Long, inputId: Int, outputId: Int,
+ blockId: Int): String = {
+ initializeIfNeeded()
+ // Adding shuffleDir is unnecessary. Added to keep the parsers working
+ return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId,
+ blockId)
+ }
+
+ def getBlockNumOutputName(shuffleId: Long, inputId: Int,
+ outputId: Int): String = {
+ initializeIfNeeded()
+ return "%s/%d/%d/%d-BLOCKNUM".format(shuffleDir, shuffleId, inputId,
+ outputId)
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shut down the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
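+ // Note: unlike the file-based variants, this send path never touches disk;
+ // both the BLOCKNUM entry and the requested blocks are looked up by name in
+ // TrackedCustomBlockedInMemoryShuffle.splitsCache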
+ override def run: Unit = {
+ try {
+ // Receive basic path information
+ var requestedSplitBase = ois.readObject.asInstanceOf[String]
+
+ logInfo("requestedSplitBase: " + requestedSplitBase)
+
+ // Read BLOCKNUM and send back the total number of blocks
+ val blockNumName = "%s/%s-BLOCKNUM".format(shuffleDir,
+ requestedSplitBase)
+
+ val blockNumIn = new ObjectInputStream(new ByteArrayInputStream(
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(blockNumName)))
+ val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
+ blockNumIn.close()
+
+ oos.writeObject(BLOCKNUM)
+ oos.flush()
+
+ val startTime = System.currentTimeMillis
+ var curTime = startTime
+ var keepSending = true
+ var numBlocksToSend = Shuffle.MaxChatBlocks
+
+ while (keepSending && numBlocksToSend > 0) {
+ // Receive specific block request
+ val blockId = ois.readObject.asInstanceOf[Int]
+
+ // Ready to send
+ var requestedSplit = shuffleDir + "/" + requestedSplitBase + "-" + blockId
+
+ // Send the length of the requestedSplit to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ var requestedSplitLen = -1
+
+ try {
+ requestedSplitLen =
+ TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length
+ } catch {
+ case e: Exception => { }
+ }
+
+ oos.writeObject(requestedSplitLen)
+ oos.flush()
+
+ logInfo("requestedSplitLen = " + requestedSplitLen)
+
+ // Read and send the requested file
+ if (requestedSplitLen != -1) {
+ // Send
+ bos.write(TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit),
+ 0, requestedSplitLen)
+ bos.flush()
+
+ // Update loop variables
+ numBlocksToSend = numBlocksToSend - 1
+
+ curTime = System.currentTimeMillis
+ // Stop sending only if there is someone waiting in the queue
+ // TODO: Turning OFF the optimization so that reducers go back to
+ // the tracker to get advice
+ if (curTime - startTime >= Shuffle.MaxChatTime /* &&
+ threadPool.getQueue.size > 0 */) {
+ keepSending = false
+ }
+ } else {
+ // Close the connection
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died,
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ // EOFException is expected to happen because the receiver can break the
+ // connection as soon as it has all the blocks
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala
new file mode 100644
index 0000000000..a27fa628c6
--- /dev/null
+++ b/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala
@@ -0,0 +1,930 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using local files served through a custom server
+ * where receivers create simultaneous connections to multiple servers by
+ * setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * By controlling the 'spark.shuffle.blockSize' config option one can also
+ * control the largest block size to divide each map output into. Essentially,
+ * instead of creating one large output file for each reducer, maps create
+ * multiple smaller files to enable a finer level of engagement.
+ *
+ * 'spark.shuffle.maxTxConnections' enforces the server-side cap. Ideally,
+ * maxTxConnections >= maxRxConnections * numReducersPerMachine
+ *
+ * 'spark.shuffle.trackerStrategy' decides which strategy to use in the tracker
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
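+ *
+ * A minimal driver-side configuration sketch. The property names are the
+ * ones read by this class and its tracker; the values are illustrative
+ * assumptions, not recommended defaults:
+ *
+ * {{{
+ * System.setProperty("spark.shuffle.maxRxConnections", "4")
+ * System.setProperty("spark.shuffle.maxTxConnections", "8")
+ * System.setProperty("spark.shuffle.blockSize", "1024")
+ * System.setProperty("spark.shuffle.trackerStrategy",
+ *   "spark.BalanceConnectionsShuffleTrackerStrategy")
+ * }}}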
+ */
+@serializable
+class TrackedCustomBlockedLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+
+ @transient var totalBlocksInSplit: Array[Int] = null
+ @transient var hasBlocksInSplit: Array[Int] = null
+
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = TrackedCustomBlockedLocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
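+ // In Scala, % can return a negative remainder for a negative hashCode,
+ // e.g. -7 % 3 == -1; the fix below adds numOutputSplits to map it back
+ // into [0, numOutputSplits)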
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ // Keep track of number of blocks for each output split
+ var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0)
+
+ for (i <- 0 until numOutputSplits) {
+ var blockNum = 0
+ var isDirty = false
+ var file: File = null
+ var out: ObjectOutputStream = null
+
+ var writeStartTime: Long = 0
+
+ buckets(i).foreach(pair => {
+ // Open a new file if necessary
+ if (!isDirty) {
+ file = TrackedCustomBlockedLocalFileShuffle.getOutputFile(shuffleId,
+ myIndex, i, blockNum)
+ writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + file)
+
+ out = new ObjectOutputStream(new FileOutputStream(file))
+ }
+
+ out.writeObject(pair)
+ out.flush()
+ isDirty = true
+
+ // Close the old file if it has crossed the blockSize limit
+ if (file.length > Shuffle.BlockSize) {
+ out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ isDirty = false
+ }
+ })
+
+ if (isDirty) {
+ out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+
+ blockNum = blockNum + 1
+ }
+
+ // Write the BLOCKNUM file
+ file = TrackedCustomBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId,
+ myIndex, i)
+ out = new ObjectOutputStream(new FileOutputStream(file))
+ out.writeObject(blockNum)
+ out.close()
+
+ // Store number of blocks for this outputSplit
+ numBlocksPerOutputSplit(i) = blockNum
+ }
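+
+ // On-disk layout produced above (cf. getOutputFile/getBlockNumOutputFile):
+ // under shuffleDir, "<shuffleId>/<inputId>/<outputId>-<blockId>" holds one
+ // block and "<shuffleId>/<inputId>/<outputId>-BLOCKNUM" the block count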
+
+ var retVal = SplitInfo(TrackedCustomBlockedLocalFileShuffle.serverAddress,
+ TrackedCustomBlockedLocalFileShuffle.serverPort, myIndex)
+ retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit
+
+ (retVal)
+ }).collect()
+
+ // Start tracker
+ var shuffleTracker = new ShuffleTracker(outputLocs)
+ shuffleTracker.setDaemon(true)
+ shuffleTracker.start()
+ logInfo("ShuffleTracker started...")
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+
+ totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1)
+ hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0)
+
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ // Local status of hasSplitsBitVector and splitsInRequestBitVector
+ val localSplitInfo = getLocalSplitInfo(myId)
+
+ // DO NOT talk to the tracker if all the required splits are already busy
+ val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality
+
+ var numThreadsToCreate =
+ Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Receive which split to pull from the tracker
+ logInfo("Talking to tracker...")
+ val startTime = System.currentTimeMillis
+ val splitIndices = getTrackerSelectedSplit(myId)
+ logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime))
+
+ if (splitIndices.size > 0) {
+ splitIndices.foreach { splitIndex =>
+ val selectedSplitInfo = outputLocs(splitIndex)
+ val requestSplit =
+ "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
+
+ threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
+ requestSplit, myId))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+ } else {
+ // The tracker replied with a NO. Sleep for a while.
+ Thread.sleep(Shuffle.MinKnockInterval)
+ numThreadsToCreate = 0
+ }
+
+ numThreadsToCreate = numThreadsToCreate - splitIndices.size
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+ // Don't return until consumption is finished
+ // while (receivedData.size > 0) {
+ // Thread.sleep(Shuffle.MinKnockInterval)
+ // }
+
+ // Wait till shuffleConsumer is done
+ shuffleConsumer.join
+
+ combiners
+ })
+ }
+
+ private def getLocalSplitInfo(myId: Int): SplitInfo = {
+ var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress,
+ SplitInfo.UnusedParam, myId)
+
+ // Store hasSplits
+ localSplitInfo.hasSplits = hasSplits
+
+ // Store hasSplitsBitVector
+ hasSplitsBitVector.synchronized {
+ localSplitInfo.hasSplitsBitVector =
+ hasSplitsBitVector.clone.asInstanceOf[BitSet]
+ }
+
+ // Copy hasBlocksInSplit into hasBlocksPerInputSplit
+ hasBlocksInSplit.synchronized {
+ localSplitInfo.hasBlocksPerInputSplit =
+ hasBlocksInSplit.clone.asInstanceOf[Array[Int]]
+ }
+
+ // Include the splitsInRequest as well
+ splitsInRequestBitVector.synchronized {
+ localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector)
+ }
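+
+ // ORing in the in-flight requests makes the tracker treat splits that are
+ // currently being fetched as already present, so it never assigns the same
+ // split to this reducer twice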
+
+ return localSplitInfo
+ }
+
+ def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+ synchronized {
+ for (i <- 0 until totalSplits) {
+ if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+ requiredSplits += i
+ }
+ }
+ }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(TrackedCustomBlockedLocalFileShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ // Talks to the tracker and receives instructions
+ private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = {
+ // Local status of hasSplitsBitVector and splitsInRequestBitVector
+ val localSplitInfo = getLocalSplitInfo(myId)
+
+ // DO NOT talk to the tracker if all the required splits are already busy
+ if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
+ return ArrayBuffer[Int]()
+ }
+
+ val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+ Shuffle.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ val oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ var selectedSplitIndices = ArrayBuffer[Int]()
+
+ // Set up the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ logInfo("Waited enough for tracker response... Take random response...")
+
+ // sockets will be closed in finally
+ // TODO: Sometimes the timer won't go off
+
+ // TODO: Selecting randomly here. The tracker won't know about it and may
+ // hit an assertion failure when this thread leaves
+
+ selectedSplitIndices = ArrayBuffer(selectRandomSplit)
+ }
+ }
+
+ var timeOutTimer = new Timer
+ // TODO: Which timeout to use?
+ // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval)
+
+ try {
+ // Send intention
+ oosTracker.writeObject(Shuffle.ReducerEntering)
+ oosTracker.flush()
+
+ // Send what this reducer has
+ oosTracker.writeObject(localSplitInfo)
+ oosTracker.flush()
+
+ // Receive reply from the tracker
+ selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]]
+
+ // Turn the timer OFF
+ timeOutTimer.cancel()
+ } catch {
+ case e: Exception => {
+ logInfo("getTrackerSelectedSplit had a " + e)
+ }
+ } finally {
+ oisTracker.close()
+ oosTracker.close()
+ clientSocketToTracker.close()
+ }
+
+ return selectedSplitIndices
+ }
+
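+ // Tracker wire protocol, as implemented in getTrackerSelectedSplit,
+ // sendLeavingNotification, and the ShuffleTracker below:
+ //   reducer -> tracker: Shuffle.ReducerEntering, then its SplitInfo
+ //   tracker -> reducer: ArrayBuffer[Int] of split indices to fetch
+ //   reducer -> tracker: Shuffle.ReducerLeaving, SplitInfo, ReceptionStats
+ //   tracker -> reducer: Int ACK (the serverSplitIndex)
+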
+ class ShuffleTracker(outputLocs: Array[SplitInfo])
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ // Create trackerStrategy object
+ val trackerStrategyClass = System.getProperty(
+ "spark.shuffle.trackerStrategy",
+ "spark.BalanceConnectionsShuffleTrackerStrategy")
+
+ val trackerStrategy =
+ Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy]
+
+ // Must initialize here by supplying the outputLocs param
+ // TODO: This could be avoided by directly passing it to the constructor
+ trackerStrategy.initialize(outputLocs)
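+
+ // The strategy is pluggable: any ShuffleTrackerStrategy implementation on
+ // the classpath can be selected through the system property above;
+ // BalanceConnectionsShuffleTrackerStrategy is only the default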
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
+ logInfo("ShuffleTracker" + serverSocket)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ logInfo("ShuffleTracker had a " + e)
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run: Unit = {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ try {
+ // Receive intention
+ val reducerIntention = ois.readObject.asInstanceOf[Int]
+
+ if (reducerIntention == Shuffle.ReducerEntering) {
+ // Receive what the reducer has
+ val reducerSplitInfo =
+ ois.readObject.asInstanceOf[SplitInfo]
+
+ // Select splits and update stats if necessary
+ var selectedSplitIndices = ArrayBuffer[Int]()
+ trackerStrategy.synchronized {
+ selectedSplitIndices = trackerStrategy.selectSplit(
+ reducerSplitInfo)
+ }
+
+ // Send reply back
+ oos.writeObject(selectedSplitIndices)
+ oos.flush()
+
+ // Update internal stats, only if receiver got the reply
+ trackerStrategy.synchronized {
+ trackerStrategy.AddReducerToSplit(reducerSplitInfo,
+ selectedSplitIndices)
+ }
+ }
+ else if (reducerIntention == Shuffle.ReducerLeaving) {
+ val reducerSplitInfo =
+ ois.readObject.asInstanceOf[SplitInfo]
+
+ // Receive reception stats: how many blocks the reducer
+ // read in how much time and from where
+ val receptionStat =
+ ois.readObject.asInstanceOf[ReceptionStats]
+
+ // Update stats
+ trackerStrategy.synchronized {
+ trackerStrategy.deleteReducerFrom(reducerSplitInfo,
+ receptionStat)
+ }
+
+ // Send ACK
+ oos.writeObject(receptionStat.serverSplitIndex)
+ oos.flush()
+ }
+ else {
+ throw new SparkException("Undefined reducerIntention")
+ }
+ } catch {
+ // EOFException is expected to happen because receiver can
+ // break connection due to timeout and pick random instead
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleTracker had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+ // Shut down the thread pool
+ threadPool.shutdown()
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
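+ // Drains receivedData, deserializing (K, C) pairs from each received byte
+ // array and folding them into combiners via mergeCombiners; the expected
+ // EOFException marks the end of one block's data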
+ override def run: Unit = {
+ // Consume everything that has already been received
+ while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo,
+ requestSplit: String, myId: Int)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ // Make sure that multiple messages don't go to the tracker
+ private var alreadySentLeavingNotification = false
+
+ // Keep track of bytes received and time spent
+ private var numBytesReceived = 0
+ private var totalTimeSpent = 0
+
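+ // Block-fetch protocol with the mapper-side ShuffleServerThread:
+ //   1. send the "shuffleId/inputId/myId" path base (requestSplit)
+ //   2. receive the split's total block count (BLOCKNUM)
+ //   3. repeatedly send the next needed blockId, then receive an Int
+ //      length followed by that many raw bytes
+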
+ override def run: Unit = {
+ // Set up the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUp()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
+
+ try {
+ // Everything will break if BLOCKNUM is not correctly received
+ // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
+ peerSocketToSource = new Socket(
+ serversplitInfo.hostAddress, serversplitInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send path information
+ oosSource.writeObject(requestSplit)
+ oosSource.flush()
+
+ // TODO: Can be optimized. No need to do it every time.
+ // Receive BLOCKNUM
+ totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int]
+
+ // Turn the timer OFF if the sender responds before the timeout
+ timeOutTimer.cancel()
+
+ while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) {
+ // Set receptionSucceeded to false before trying for each block
+ receptionSucceeded = false
+
+ // Request specific block
+ oosSource.writeObject(hasBlocksInSplit(splitIndex))
+ oosSource.flush()
+
+ // Good to go. First, receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+ // Create a temp variable to be used in different places
+ val requestPath = "http://%s:%d/shuffle/%s-%d".format(
+ serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit,
+ hasBlocksInSplit(splitIndex))
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: " + requestPath)
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ while (alreadyRead < requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
+
+ // Split has been received only if all the blocks have been received
+ if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: " + requestPath)
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading " + requestPath + " took " + readTime + " millis.")
+
+ // Update stats
+ numBytesReceived = numBytesReceived + requestedFileLen
+ totalTimeSpent = totalTimeSpent + readTime.toInt
+ } else {
+ throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit)
+ }
+ }
+ } catch {
+ // EOFException is expected to happen because the sender can break the
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ cleanUp()
+ }
+ }
+
+ // Connect to the tracker and update its stats
+ private def sendLeavingNotification(): Unit = synchronized {
+ if (!alreadySentLeavingNotification) {
+ val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+ Shuffle.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ val oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ try {
+ // Send intention
+ oosTracker.writeObject(Shuffle.ReducerLeaving)
+ oosTracker.flush()
+
+ // Send reducerSplitInfo
+ oosTracker.writeObject(getLocalSplitInfo(myId))
+ oosTracker.flush()
+
+ // Send reception stats
+ oosTracker.writeObject(ReceptionStats(
+ numBytesReceived, totalTimeSpent, splitIndex))
+ oosTracker.flush()
+
+ // Receive ACK. No need to do anything with that
+ oisTracker.readObject.asInstanceOf[Int]
+
+ // Now update alreadySentLeavingNotification
+ alreadySentLeavingNotification = true
+ } catch {
+ case e: Exception => {
+ logInfo("sendLeavingNotification had a " + e)
+ }
+ } finally {
+ oisTracker.close()
+ oosTracker.close()
+ clientSocketToTracker.close()
+ }
+ }
+ }
+
+ private def cleanUp(): Unit = {
+ // Update tracker stats first
+ sendLeavingNotification()
+
+ // Clean up the connections to the mapper
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+
+ logInfo("Leaving client")
+ }
+ }
+}
+
+object TrackedCustomBlockedLocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int,
+ blockId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "%d-%d".format(outputId, blockId))
+ return file
+ }
+
+ def getBlockNumOutputFile(shuffleId: Long, inputId: Int,
+ outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, outputId + "-BLOCKNUM")
+ return file
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shut down the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
+ override def run: Unit = {
+ try {
+ // Receive basic path information
+ var requestPathBase = ois.readObject.asInstanceOf[String]
+
+ logInfo("requestPathBase: " + requestPathBase)
+
+ // Read BLOCKNUM file and send back the total number of blocks
+ val blockNumFilePath = "%s/%s-BLOCKNUM".format(shuffleDir,
+ requestPathBase)
+ val blockNumIn =
+ new ObjectInputStream(new FileInputStream(blockNumFilePath))
+ val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int]
+ blockNumIn.close()
+
+ oos.writeObject(BLOCKNUM)
+ oos.flush()
+
+ val startTime = System.currentTimeMillis
+ var curTime = startTime
+ var keepSending = true
+ var numBlocksToSend = Shuffle.MaxChatBlocks
+
+ while (keepSending && numBlocksToSend > 0) {
+ // Receive specific block request
+ val blockId = ois.readObject.asInstanceOf[Int]
+
+ // Ready to send
+ var requestPath = requestPathBase + "-" + blockId
+
+ // Open the file
+ var requestedFile: File = null
+ var requestedFileLen = -1
+ try {
+ requestedFile = new File(shuffleDir + "/" + requestPath)
+ requestedFileLen = requestedFile.length.toInt
+ } catch {
+ case e: Exception => { }
+ }
+
+ // Send the length of the requested file to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ oos.writeObject(requestedFileLen)
+ oos.flush()
+
+ logInfo("requestedFileLen = " + requestedFileLen)
+
+ // Read and send the requested file
+ if (requestedFileLen != -1) {
+ // Read
+ var byteArray = new Array[Byte](requestedFileLen)
+ val bis =
+ new BufferedInputStream(new FileInputStream(requestedFile))
+
+ var bytesRead = bis.read(byteArray, 0, byteArray.length)
+ var alreadyRead = bytesRead
+
+ while (alreadyRead < requestedFileLen) {
+ bytesRead = bis.read(byteArray, alreadyRead,
+ (byteArray.length - alreadyRead))
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+ bis.close()
+
+ // Send
+ bos.write(byteArray, 0, byteArray.length)
+ bos.flush()
+
+ // Update loop variables
+ numBlocksToSend = numBlocksToSend - 1
+
+ curTime = System.currentTimeMillis
+ // Stop sending only if there is someone waiting in the queue
+ // TODO: Turning OFF the optimization so that reducers go back to
+ // the tracker to get advice
+ if (curTime - startTime >= Shuffle.MaxChatTime /* &&
+ threadPool.getQueue.size > 0 */) {
+ keepSending = false
+ }
+ } else {
+ // Close the connection
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died,
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ // EOFException is expected to happen because the receiver can break the
+ // connection as soon as it has all the blocks
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
new file mode 100644
index 0000000000..4a38a8d7ff
--- /dev/null
+++ b/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
@@ -0,0 +1,802 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * An implementation of shuffle using local files served through a custom server
+ * where receivers create simultaneous connections to multiple servers by
+ * setting the 'spark.shuffle.maxRxConnections' config option.
+ *
+ * 'spark.shuffle.maxTxConnections' enforces the server-side cap. Ideally,
+ * maxTxConnections >= maxRxConnections * numReducersPerMachine
+ *
+ * 'spark.shuffle.trackerStrategy' decides which strategy to use in the tracker
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
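+ *
+ * A sketch of how the connection fan-out might be configured (illustrative
+ * values only; with R reducers per machine, keep maxTx >= maxRx * R):
+ *
+ * {{{
+ * System.setProperty("spark.shuffle.maxRxConnections", "4")
+ * System.setProperty("spark.shuffle.maxTxConnections", "8")
+ * }}}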
+ */
+@serializable
+class TrackedCustomParallelLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+ @transient var totalSplits = 0
+ @transient var hasSplits = 0
+ @transient var hasSplitsBitVector: BitSet = null
+ @transient var splitsInRequestBitVector: BitSet = null
+
+ @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+ @transient var combiners: HashMap[K,C] = null
+
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = TrackedCustomParallelLocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+
+ for (i <- 0 until numOutputSplits) {
+ val file = TrackedCustomParallelLocalFileShuffle.getOutputFile(shuffleId,
+ myIndex, i)
+ val writeStartTime = System.currentTimeMillis
+ logInfo("BEGIN WRITE: " + file)
+ val out = new ObjectOutputStream(new FileOutputStream(file))
+ buckets(i).foreach(pair => out.writeObject(pair))
+ out.close()
+ logInfo("END WRITE: " + file)
+ val writeTime = System.currentTimeMillis - writeStartTime
+ logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+ }
+
+ SplitInfo(TrackedCustomParallelLocalFileShuffle.serverAddress,
+ TrackedCustomParallelLocalFileShuffle.serverPort, myIndex)
+ }).collect()
+
+ // Start tracker
+ var shuffleTracker = new ShuffleTracker(outputLocs)
+ shuffleTracker.setDaemon(true)
+ shuffleTracker.start()
+ logInfo("ShuffleTracker started...")
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ totalSplits = outputLocs.size
+ hasSplits = 0
+ hasSplitsBitVector = new BitSet(totalSplits)
+ splitsInRequestBitVector = new BitSet(totalSplits)
+
+ receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+ combiners = new HashMap[K, C]
+
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
+
+ while (hasSplits < totalSplits) {
+ // Local status of hasSplitsBitVector and splitsInRequestBitVector
+ val localSplitInfo = getLocalSplitInfo(myId)
+
+ // DO NOT talk to the tracker if all the required splits are already busy
+ val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality
+
+ var numThreadsToCreate =
+ Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) -
+ threadPool.getActiveCount
+
+ while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+ // Receive which split to pull from the tracker
+ logInfo("Talking to tracker...")
+ val startTime = System.currentTimeMillis
+ val splitIndices = getTrackerSelectedSplit(myId)
+ logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime))
+
+ if (splitIndices.size > 0) {
+ splitIndices.foreach { splitIndex =>
+ val selectedSplitInfo = outputLocs(splitIndex)
+ val requestSplit =
+ "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
+
+ threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
+ requestSplit, myId))
+
+ // splitIndex is in transit. Will be unset in the ShuffleClient
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex)
+ }
+ }
+ } else {
+ // The tracker replied with a NO. Sleep for a while.
+ Thread.sleep(Shuffle.MinKnockInterval)
+ numThreadsToCreate = 0
+ }
+
+ numThreadsToCreate = numThreadsToCreate - splitIndices.size
+ }
+
+ // Sleep for a while before creating new threads
+ Thread.sleep(Shuffle.MinKnockInterval)
+ }
+
+ threadPool.shutdown()
+
+ // Start consumer
+ // TODO: Consumption is delayed until everything has been received.
+ // Otherwise it interferes with network performance
+ var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+ shuffleConsumer.setDaemon(true)
+ shuffleConsumer.start()
+ logInfo("ShuffleConsumer started...")
+
+ // Don't return until consumption is finished
+ // while (receivedData.size > 0) {
+ // Thread.sleep(Shuffle.MinKnockInterval)
+ // }
+
+ // Wait till shuffleConsumer is done
+ shuffleConsumer.join
+
+ combiners
+ })
+ }
+
+ private def getLocalSplitInfo(myId: Int): SplitInfo = {
+ var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress,
+ SplitInfo.UnusedParam, myId)
+
+ localSplitInfo.hasSplits = hasSplits
+
+ hasSplitsBitVector.synchronized {
+ localSplitInfo.hasSplitsBitVector =
+ hasSplitsBitVector.clone.asInstanceOf[BitSet]
+ }
+
+ // Include the splitsInRequest as well
+ splitsInRequestBitVector.synchronized {
+ localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector)
+ }
+
+ return localSplitInfo
+ }
+
+ // Selects a random split using local information
+ private def selectRandomSplit: Int = {
+ var requiredSplits = new ArrayBuffer[Int]
+
+ synchronized {
+ for (i <- 0 until totalSplits) {
+ if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+ requiredSplits += i
+ }
+ }
+ }
+
+ if (requiredSplits.size > 0) {
+ requiredSplits(TrackedCustomParallelLocalFileShuffle.ranGen.nextInt(
+ requiredSplits.size))
+ } else {
+ -1
+ }
+ }
+
+ // Talks to the tracker and receives instructions
+ private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = {
+ // Local status of hasSplitsBitVector and splitsInRequestBitVector
+ val localSplitInfo = getLocalSplitInfo(myId)
+
+ // DO NOT talk to the tracker if all the required splits are already busy
+ if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
+ return ArrayBuffer[Int]()
+ }
+
+ val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+ Shuffle.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ val oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ var selectedSplitIndices = ArrayBuffer[Int]()
+
+ // Set up the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ logInfo("Waited enough for tracker response... Take random response...")
+
+ // sockets will be closed in finally
+ // TODO: Sometimes the timer won't go off
+
+ // TODO: Selecting randomly here. The tracker won't know about it and may
+ // hit an assertion failure when this thread leaves
+
+ selectedSplitIndices = ArrayBuffer(selectRandomSplit)
+ }
+ }
+
+ var timeOutTimer = new Timer
+ // TODO: Which timeout to use?
+ // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval)
+
+ try {
+ // Send intention
+ oosTracker.writeObject(Shuffle.ReducerEntering)
+ oosTracker.flush()
+
+ // Send what this reducer has
+ oosTracker.writeObject(localSplitInfo)
+ oosTracker.flush()
+
+ // Receive reply from the tracker
+ selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]]
+
+ // Turn the timer OFF
+ timeOutTimer.cancel()
+ } catch {
+ case e: Exception => {
+ logInfo("getTrackerSelectedSplit had a " + e)
+ }
+ } finally {
+ oisTracker.close()
+ oosTracker.close()
+ clientSocketToTracker.close()
+ }
+
+ return selectedSplitIndices
+ }
+
+ class ShuffleTracker(outputLocs: Array[SplitInfo])
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ // Create trackerStrategy object
+ val trackerStrategyClass = System.getProperty(
+ "spark.shuffle.trackerStrategy",
+ "spark.BalanceConnectionsShuffleTrackerStrategy")
+
+ val trackerStrategy =
+ Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy]
+
+ // Must initialize here by supplying the outputLocs param
+ // TODO: This could be avoided by directly passing it to the constructor
+ trackerStrategy.initialize(outputLocs)
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
+ logInfo("ShuffleTracker" + serverSocket)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ logInfo("ShuffleTracker had a " + e)
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run: Unit = {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ try {
+ // Receive intention
+ val reducerIntention = ois.readObject.asInstanceOf[Int]
+
+ if (reducerIntention == Shuffle.ReducerEntering) {
+ // Receive what the reducer has
+ val reducerSplitInfo =
+ ois.readObject.asInstanceOf[SplitInfo]
+
+ // Select split and update stats if necessary
+ var selectedSplitIndices = ArrayBuffer[Int]()
+ trackerStrategy.synchronized {
+ selectedSplitIndices = trackerStrategy.selectSplit(
+ reducerSplitInfo)
+ }
+
+ // Send reply back
+ oos.writeObject(selectedSplitIndices)
+ oos.flush()
+
+ // Update internal stats, only if receiver got the reply
+ trackerStrategy.synchronized {
+ trackerStrategy.AddReducerToSplit(reducerSplitInfo,
+ selectedSplitIndices)
+ }
+ }
+ else if (reducerIntention == Shuffle.ReducerLeaving) {
+ val reducerSplitInfo =
+ ois.readObject.asInstanceOf[SplitInfo]
+
+ // Receive reception stats: how many blocks the reducer
+ // read in how much time and from where
+ val receptionStat =
+ ois.readObject.asInstanceOf[ReceptionStats]
+
+ // Update stats
+ trackerStrategy.synchronized {
+ trackerStrategy.deleteReducerFrom(reducerSplitInfo,
+ receptionStat)
+ }
+
+ // Send ACK
+ oos.writeObject(receptionStat.serverSplitIndex)
+ oos.flush()
+ }
+ else {
+ throw new SparkException("Undefined reducerIntention")
+ }
+ } catch {
+ // EOFException is expected to happen because receiver can
+ // break connection due to timeout and pick random instead
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleTracker had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+ // Shut down the thread pool
+ threadPool.shutdown()
+ }
+ }
+
+ class ShuffleConsumer(mergeCombiners: (C, C) => C)
+ extends Thread with Logging {
+ override def run: Unit = {
+ // Consume everything that has already been received
+ while (receivedData.size > 0) {
+ var splitIndex = -1
+ var recvByteArray: Array[Byte] = null
+
+ try {
+ var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+ splitIndex = tempPair._1
+ recvByteArray = tempPair._2
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during taking data from receivedData")
+ }
+ }
+
+ val inputStream =
+ new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => { }
+ }
+ inputStream.close()
+ }
+ }
+ }
+
+ class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo,
+ requestSplit: String, myId: Int)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ private var receptionSucceeded = false
+
+ // Make sure that multiple messages don't go to the tracker
+ private var alreadySentLeavingNotification = false
+
+ // Keep track of bytes received and time spent
+ private var numBytesReceived = 0
+ private var totalTimeSpent = 0
+
+ override def run: Unit = {
+ // Set up the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run: Unit = {
+ cleanUp()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
+
+ // Create a temp variable to be used in different places
+ val requestPath = "http://%s:%d/shuffle/%s".format(
+ serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit)
+
+ logInfo("ShuffleClient started... => " + requestPath)
+
+ try {
+ // Connect to the source
+ peerSocketToSource = new Socket(
+ serversplitInfo.hostAddress, serversplitInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ var isSource = peerSocketToSource.getInputStream
+ oisSource = new ObjectInputStream(isSource)
+
+ // Send the request
+ oosSource.writeObject(requestSplit)
+ oosSource.flush()
+
+ // Receive the length of the requested file
+ var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+ logInfo("Received requestedFileLen = " + requestedFileLen)
+
+ // Turn the timer OFF if the sender responds before the timeout
+ timeOutTimer.cancel()
+
+ // Receive the file
+ if (requestedFileLen != -1) {
+ val readStartTime = System.currentTimeMillis
+ logInfo("BEGIN READ: " + requestPath)
+
+ // Receive data in an Array[Byte]
+ var recvByteArray = new Array[Byte](requestedFileLen)
+ var alreadyRead = 0
+ var bytesRead = 0
+
+ while (alreadyRead < requestedFileLen) {
+ bytesRead = isSource.read(recvByteArray, alreadyRead,
+ requestedFileLen - alreadyRead)
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+
+ // Make it available to the consumer
+ try {
+ receivedData.put((splitIndex, recvByteArray))
+ } catch {
+ case e: Exception => {
+ logInfo("Exception during putting data into receivedData")
+ }
+ }
+
+ // TODO: Updating stats before consumption is completed
+ hasSplitsBitVector.synchronized {
+ hasSplitsBitVector.set(splitIndex)
+ hasSplits += 1
+ }
+
+ // We have received splitIndex
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+
+ receptionSucceeded = true
+
+ logInfo("END READ: " + requestPath)
+ val readTime = System.currentTimeMillis - readStartTime
+ logInfo("Reading " + requestPath + " took " + readTime + " millis.")
+
+ // Update stats
+ numBytesReceived = numBytesReceived + requestedFileLen
+ totalTimeSpent = totalTimeSpent + readTime.toInt
+ } else {
+ throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit)
+ }
+ } catch {
+ // EOFException is expected to happen because the sender can break the
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logInfo("ShuffleClient had a " + e)
+ }
+ } finally {
+ // If reception failed, unset for future retry
+ if (!receptionSucceeded) {
+ splitsInRequestBitVector.synchronized {
+ splitsInRequestBitVector.set(splitIndex, false)
+ }
+ }
+ cleanUp()
+ }
+ }
+
+ // Connect to the tracker and update its stats
+ private def sendLeavingNotification(): Unit = synchronized {
+ if (!alreadySentLeavingNotification) {
+ val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+ Shuffle.MasterTrackerPort)
+ val oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ val oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ try {
+ // Send intention
+ oosTracker.writeObject(Shuffle.ReducerLeaving)
+ oosTracker.flush()
+
+ // Send reducerSplitInfo
+ oosTracker.writeObject(getLocalSplitInfo(myId))
+ oosTracker.flush()
+
+ // Send reception stats
+ oosTracker.writeObject(ReceptionStats(
+ numBytesReceived, totalTimeSpent, splitIndex))
+ oosTracker.flush()
+
+ // Receive ACK. No need to do anything with that
+ oisTracker.readObject.asInstanceOf[Int]
+
+ // Now update alreadySentLeavingNotification
+ alreadySentLeavingNotification = true
+ } catch {
+ case e: Exception => {
+ logInfo("sendLeavingNotification had a " + e)
+ }
+ } finally {
+ oisTracker.close()
+ oosTracker.close()
+ clientSocketToTracker.close()
+ }
+ }
+ }
+
+ private def cleanUp(): Unit = {
+ // Update tracker stats first
+ sendLeavingNotification()
+
+ // Clean up the connections to the mapper
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+ }
+ }
+}
+
+object TrackedCustomParallelLocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+
+ private var shuffleServer: ShuffleServer = null
+ private var serverAddress = InetAddress.getLocalHost.getHostAddress
+ private var serverPort: Int = -1
+
+ // Random number generator
+ var ranGen = new Random
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+
+ // Create and start the shuffleServer
+ shuffleServer = new ShuffleServer
+ shuffleServer.setDaemon(true)
+ shuffleServer.start()
+ logInfo("ShuffleServer started...")
+
+ initialized = true
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "" + outputId)
+ return file
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+
+ class ShuffleServer
+ extends Thread with Logging {
+ var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
+
+ var serverSocket: ServerSocket = null
+
+ override def run: Unit = {
+ serverSocket = new ServerSocket(0)
+ serverPort = serverSocket.getLocalPort
+
+ logInfo("ShuffleServer started with " + serverSocket)
+ logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+ try {
+ while (true) {
+ var clientSocket: Socket = null
+ try {
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ShuffleServerThread(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ShuffleServer now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shut down the thread pool
+ threadPool.shutdown()
+ }
+
+ class ShuffleServerThread(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
+ os.flush()
+ private val bos = new BufferedOutputStream(os)
+ bos.flush()
+ private val oos = new ObjectOutputStream(os)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ShuffleServerThread is running")
+
+ override def run: Unit = {
+ try {
+ // Receive requestPath from the receiver
+ var requestPath = ois.readObject.asInstanceOf[String]
+ logInfo("requestPath: " + shuffleDir + "/" + requestPath)
+
+ // Open the file
+ var requestedFile: File = null
+ var requestedFileLen = -1
+ try {
+ requestedFile = new File(shuffleDir + "/" + requestPath)
+ requestedFileLen = requestedFile.length.toInt
+ } catch {
+ case e: Exception => { }
+ }
+
+ // Send the length of the requested file to let the receiver know that
+ // transfer is about to start
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ oos.writeObject(requestedFileLen)
+ oos.flush()
+
+ logInfo("requestedFileLen = " + requestedFileLen)
+
+ // Read and send the requested file
+ if (requestedFileLen != -1) {
+ // Read
+ var byteArray = new Array[Byte](requestedFileLen)
+ val bis =
+ new BufferedInputStream(new FileInputStream(requestedFile))
+
+ var bytesRead = bis.read(byteArray, 0, byteArray.length)
+ var alreadyRead = bytesRead
+
+ while (alreadyRead < requestedFileLen) {
+ bytesRead = bis.read(byteArray, alreadyRead,
+ (byteArray.length - alreadyRead))
+ if (bytesRead > 0) {
+ alreadyRead = alreadyRead + bytesRead
+ }
+ }
+ bis.close()
+
+ // Send
+ bos.write(byteArray, 0, byteArray.length)
+ bos.flush()
+ } else {
+ // Close the connection
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died,
+ // then close everything up
+ // Exception can happen if the receiver stops receiving
+ case e: Exception => {
+ logInfo("ShuffleServerThread had a " + e)
+ }
+ } finally {
+ logInfo("ShuffleServerThread is closing streams and sockets")
+ ois.close()
+ // TODO: Following can cause "java.net.SocketException: Socket closed"
+ oos.close()
+ bos.close()
+ clientSocket.close()
+ }
+ }
+ }
+ }
+}