diff options
author | Mosharaf Chowdhury <mosharaf@cs.berkeley.edu> | 2011-04-27 20:47:07 -0700 |
---|---|---|
committer | Mosharaf Chowdhury <mosharaf@cs.berkeley.edu> | 2011-04-27 20:47:07 -0700 |
commit | 9d78779257b156bec335af4ab2a66bb3cac30ca6 (patch) | |
tree | 738420e4f3653510e2803b6cb6764c261feeb8c8 | |
parent | 4e4c41026c33d9f8fe75137f32266de75a0aa30e (diff) | |
parent | ac7e066383a6878beb0618597c2be6fa9eb1982e (diff) | |
download | spark-9d78779257b156bec335af4ab2a66bb3cac30ca6.tar.gz spark-9d78779257b156bec335af4ab2a66bb3cac30ca6.tar.bz2 spark-9d78779257b156bec335af4ab2a66bb3cac30ca6.zip |
Merge branch 'mos-shuffle-tracked' into mos-bt
Conflicts:
core/src/main/scala/spark/Broadcast.scala
21 files changed, 7608 insertions, 12 deletions
diff --git a/conf/java-opts b/conf/java-opts new file mode 100644 index 0000000000..1a598061f0 --- /dev/null +++ b/conf/java-opts @@ -0,0 +1,14 @@ +-Dspark.shuffle.class=spark.CustomBlockedInMemoryShuffle +-Dspark.shuffle.masterHostAddress=127.0.0.1 +-Dspark.shuffle.masterTrackerPort=22222 +-Dspark.shuffle.trackerStrategy=spark.BalanceRemainingShuffleTrackerStrategy +-Dspark.shuffle.maxRxConnections=40 +-Dspark.shuffle.maxTxConnections=120 +-Dspark.shuffle.blockSize=4096 +-Dspark.shuffle.minKnockInterval=100 +-Dspark.shuffle.maxKnockInterval=5000 +-Dspark.shuffle.maxChatTime=500 +-Dspark.shuffle.throttleFraction=2.0 +-verbose:gc +-XX:+PrintGCTimeStamps +-XX:+PrintGCDetails diff --git a/conf/log4j.properties b/conf/log4j.properties new file mode 100644 index 0000000000..33774b463d --- /dev/null +++ b/conf/log4j.properties @@ -0,0 +1,8 @@ +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN diff --git a/conf/spark-env.sh b/conf/spark-env.sh new file mode 100755 index 0000000000..5f6c8269e8 --- /dev/null +++ b/conf/spark-env.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +# Set Spark environment variables for your site in this file. Some useful +# variables to set are: +# - MESOS_HOME, to point to your Mesos installation +# - SCALA_HOME, to point to your Scala installation +# - SPARK_CLASSPATH, to add elements to Spark's classpath +# - SPARK_JAVA_OPTS, to add JVM options +# - SPARK_MEM, to change the amount of memory used per node (this should +# be in the same format as the JVM's -Xmx option, e.g. 300m or 1g). +# - SPARK_LIBRARY_PATH, to add extra search paths for native libraries. + +MESOS_HOME=/Users/mosharaf/Work/mesos diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/BasicLocalFileShuffle.scala index 367599cfb4..3c3f132083 100644 --- a/core/src/main/scala/spark/LocalFileShuffle.scala +++ b/core/src/main/scala/spark/BasicLocalFileShuffle.scala @@ -7,14 +7,13 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.{ArrayBuffer, HashMap} - /** - * A simple implementation of shuffle using local files served through HTTP. + * A basic implementation of shuffle using local files served through HTTP. * * TODO: Add support for compression when spark.compress is set to true. */ @serializable -class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { +class BasicLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { override def compute(input: RDD[(K, V)], numOutputSplits: Int, createCombiner: V => C, @@ -23,7 +22,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { : RDD[(K, C)] = { val sc = input.sparkContext - val shuffleId = LocalFileShuffle.newShuffleId() + val shuffleId = BasicLocalFileShuffle.newShuffleId() logInfo("Shuffle ID: " + shuffleId) val splitRdd = new NumberedSplitRDD(input) @@ -46,13 +45,20 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { case None => createCombiner(v) } } + for (i <- 0 until numOutputSplits) { - val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i) + val file = BasicLocalFileShuffle.getOutputFile(shuffleId, myIndex, i) + val writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) val out = new ObjectOutputStream(new FileOutputStream(file)) buckets(i).foreach(pair => out.writeObject(pair)) out.close() + logInfo("END WRITE: " + file) + val writeTime = (System.currentTimeMillis - writeStartTime) + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") } - (myIndex, LocalFileShuffle.serverUri) + + (myIndex, BasicLocalFileShuffle.serverUri) }).collect() // Build a hashmap from server URI to list of splits (to facillitate @@ -71,6 +77,8 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) { for (i <- inputIds) { val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId) + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: " + url) val inputStream = new ObjectInputStream(new URL(url).openStream()) try { while (true) { @@ -84,6 +92,9 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { case e: EOFException => {} } inputStream.close() + logInfo("END READ: " + url) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading " + url + " took " + readTime + " millis.") } } combiners @@ -91,8 +102,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { } } - -object LocalFileShuffle extends Logging { +object BasicLocalFileShuffle extends Logging { private var initialized = false private var nextShuffleId = new AtomicLong(0) @@ -113,9 +123,9 @@ object LocalFileShuffle extends Logging { while (!foundLocalDir && tries < 10) { tries += 1 try { - localDirUuid = UUID.randomUUID() + localDirUuid = UUID.randomUUID localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists()) { + if (!localDir.exists) { localDir.mkdirs() foundLocalDir = true } @@ -131,6 +141,7 @@ object LocalFileShuffle extends Logging { shuffleDir = new File(localDir, "shuffle") shuffleDir.mkdirs() logInfo("Shuffle dir: " + shuffleDir) + val extServerPort = System.getProperty( "spark.localFileShuffle.external.server.port", "-1").toInt if (extServerPort != -1) { @@ -149,6 +160,7 @@ object LocalFileShuffle extends Logging { serverUri = server.uri } initialized = true + logInfo("Local URI: " + serverUri) } } diff --git a/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala b/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala new file mode 100644 index 0000000000..5aef43f302 --- /dev/null +++ b/core/src/main/scala/spark/CustomBlockedInMemoryShuffle.scala @@ -0,0 +1,629 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW). + * + * An implementation of shuffle using local memory served through custom server + * where receivers create simultaneous connections to multiple servers by + * setting the 'spark.shuffle.maxRxConnections' config option. + * + * By controlling the 'spark.shuffle.blockSize' config option one can also + * control the largest block size to divide each map output into. Essentially, + * instead of creating one large output file for each reducer, maps create + * multiple smaller files to enable finer level of engagement. + * + * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, + * maxTxConnections >= maxRxConnections * numReducersPerMachine + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class CustomBlockedInMemoryShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + + @transient var totalBlocksInSplit: Array[Int] = null + @transient var hasBlocksInSplit: Array[Int] = null + + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = CustomBlockedInMemoryShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + var blockNum = 0 + var isDirty = false + + var splitName = "" + var baos: ByteArrayOutputStream = null + var oos: ObjectOutputStream = null + + var writeStartTime: Long = 0 + + buckets(i).foreach(pair => { + // Open a new stream if necessary + if (!isDirty) { + splitName = CustomBlockedInMemoryShuffle.getSplitName(shuffleId, + myIndex, i, blockNum) + + baos = new ByteArrayOutputStream + oos = new ObjectOutputStream(baos) + + writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + splitName) + } + + oos.writeObject(pair) + isDirty = true + + // Close the old stream if has crossed the blockSize limit + if (baos.size > Shuffle.BlockSize) { + CustomBlockedInMemoryShuffle.splitsCache(splitName) = + baos.toByteArray + + logInfo("END WRITE: " + splitName) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + isDirty = false + oos.close() + } + }) + + if (isDirty) { + CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray + + logInfo("END WRITE: " + splitName) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + oos.close() + } + + // Store BLOCKNUM info + splitName = CustomBlockedInMemoryShuffle.getBlockNumOutputName( + shuffleId, myIndex, i) + baos = new ByteArrayOutputStream + oos = new ObjectOutputStream(baos) + oos.writeObject(blockNum) + CustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray + + // Close streams + oos.close() + } + + (myIndex, CustomBlockedInMemoryShuffle.serverAddress, + CustomBlockedInMemoryShuffle.serverPort) + }).collect() + + val splitsByUri = new ArrayBuffer[(String, Int, Int)] + for ((inputId, serverAddress, serverPort) <- outputLocs) { + splitsByUri += ((serverAddress, serverPort, inputId)) + } + + // TODO: Could broadcast outputLocs + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = outputLocs.size + hasSplits = 0 + + totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) + hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) + + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + var numThreadsToCreate = + Math.min(totalSplits, Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex) + + threadPool.execute(new ShuffleClient(serverAddress, serverPort, + shuffleId.toInt, inputId, myId, splitIndex)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(CustomBlockedInMemoryShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(hostAddress: String, listenPort: Int, shuffleId: Int, + inputId: Int, myId: Int, splitIndex: Int) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + private var receptionSucceeded = false + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) + + try { + // Everything will break if BLOCKNUM is not correctly received + // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown + peerSocketToSource = new Socket(hostAddress, listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream(isSource) + + // Send path information + oosSource.writeObject((shuffleId, inputId, myId)) + + // TODO: Can be optimized. No need to do it everytime. + // Receive BLOCKNUM + totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int] + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) { + // Set receptionSucceeded to false before trying for each block + receptionSucceeded = false + + // Request specific block + oosSource.writeObject(hasBlocksInSplit(splitIndex)) + + // Good to go. First, receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo("Received requestedFileLen = " + requestedFileLen) + + val requestSplit = "%d/%d/%d-%d".format(shuffleId, inputId, myId, + hasBlocksInSplit(splitIndex)) + + // Receive the file + if (requestedFileLen != -1) { + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 + + // Split has been received only if all the blocks have been received + if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + } + + // Consistent state in accounting variables + receptionSucceeded = true + + logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.") + } else { + throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit) + } + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + cleanUpConnections() + } + } + + private def cleanUpConnections(): Unit = { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + } + } +} + +object CustomBlockedInMemoryShuffle extends Logging { + // Cache for keeping the splits around + val splitsCache = new HashMap[String, Array[Byte]] + + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def getSplitName(shuffleId: Long, inputId: Int, outputId: Int, + blockId: Int): String = { + initializeIfNeeded() + // Adding shuffleDir is unnecessary. Added to keep the parsers working + return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId, + blockId) + } + + def getBlockNumOutputName(shuffleId: Long, inputId: Int, + outputId: Int): String = { + initializeIfNeeded() + return "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir, shuffleId, inputId, + outputId) + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + class ShuffleServer + extends Thread with Logging { + var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket(0) + serverPort = serverSocket.getLocalPort + + logInfo("ShuffleServer started with " + serverSocket) + logInfo("Local URI: http://" + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ShuffleServerThread(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ShuffleServer now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ShuffleServerThread(val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush() + private val bos = new BufferedOutputStream(os) + bos.flush() + private val oos = new ObjectOutputStream(os) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive basic path information + val (shuffleId, myIndex, outputId) = + ois.readObject.asInstanceOf[(Int, Int, Int)] + + var requestedSplitBase = "%s/%d/%d/%d".format( + shuffleDir, shuffleId, myIndex, outputId) + logInfo("requestedSplitBase: " + requestedSplitBase) + + // Read BLOCKNUM and send back the total number of blocks + val blockNumName = "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir, + shuffleId, myIndex, outputId) + + val blockNumIn = new ObjectInputStream(new ByteArrayInputStream( + CustomBlockedInMemoryShuffle.splitsCache(blockNumName))) + val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int] + blockNumIn.close() + + oos.writeObject(BLOCKNUM) + + val startTime = System.currentTimeMillis + var curTime = startTime + var keepSending = true + var numBlocksToSend = Shuffle.MaxChatBlocks + + while (keepSending && numBlocksToSend > 0) { + // Receive specific block request + val blockId = ois.readObject.asInstanceOf[Int] + + // Ready to send + var requestedSplit = requestedSplitBase + "-" + blockId + + // Send the length of the requestedSplit to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + var requestedSplitLen = -1 + + try { + requestedSplitLen = + CustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length + } catch { + case e: Exception => { } + } + + oos.writeObject(requestedSplitLen) + oos.flush() + + logInfo("requestedSplitLen = " + requestedSplitLen) + + // Read and send the requested file + if (requestedSplitLen != -1) { + // Send + bos.write(CustomBlockedInMemoryShuffle.splitsCache(requestedSplit), + 0, requestedSplitLen) + bos.flush() + + // Update loop variables + numBlocksToSend = numBlocksToSend - 1 + + curTime = System.currentTimeMillis + // Revoke sending only if there is anyone waiting in the queue + if (curTime - startTime >= Shuffle.MaxChatTime && + threadPool.getQueue.size > 0) { + keepSending = false + } + } else { + // Close the connection + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc + // then close everything up + // Exception can happen if the receiver stops receiving + // EOFException is expected to happen because receiver can break + // connection as soon as it has all the blocks + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleServerThread had a " + e) + } + } finally { + logInfo("ShuffleServerThread is closing streams and sockets") + ois.close() + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close() + bos.close() + clientSocket.close() + } + } + } + } +} diff --git a/core/src/main/scala/spark/CustomBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/CustomBlockedLocalFileShuffle.scala new file mode 100644 index 0000000000..98af7c8d65 --- /dev/null +++ b/core/src/main/scala/spark/CustomBlockedLocalFileShuffle.scala @@ -0,0 +1,654 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * An implementation of shuffle using local files served through custom server + * where receivers create simultaneous connections to multiple servers by + * setting the 'spark.shuffle.maxRxConnections' config option. + * + * By controlling the 'spark.shuffle.blockSize' config option one can also + * control the largest block size to divide each map output into. Essentially, + * instead of creating one large output file for each reducer, maps create + * multiple smaller files to enable finer level of engagement. + * + * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, + * maxTxConnections >= maxRxConnections * numReducersPerMachine + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class CustomBlockedLocalFileShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + + @transient var totalBlocksInSplit: Array[Int] = null + @transient var hasBlocksInSplit: Array[Int] = null + + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = CustomBlockedLocalFileShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + var blockNum = 0 + var isDirty = false + var file: File = null + var out: ObjectOutputStream = null + + var writeStartTime: Long = 0 + + buckets(i).foreach(pair => { + // Open a new file if necessary + if (!isDirty) { + file = CustomBlockedLocalFileShuffle.getOutputFile(shuffleId, + myIndex, i, blockNum) + writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) + + out = new ObjectOutputStream(new FileOutputStream(file)) + } + + out.writeObject(pair) + out.flush() + isDirty = true + + // Close the old file if has crossed the blockSize limit + if (file.length > Shuffle.BlockSize) { + out.close() + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + isDirty = false + } + }) + + if (isDirty) { + out.close() + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + } + + // Write the BLOCKNUM file + file = CustomBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId, + myIndex, i) + out = new ObjectOutputStream(new FileOutputStream(file)) + out.writeObject(blockNum) + out.close() + } + + (myIndex, CustomBlockedLocalFileShuffle.serverAddress, + CustomBlockedLocalFileShuffle.serverPort) + }).collect() + + val splitsByUri = new ArrayBuffer[(String, Int, Int)] + for ((inputId, serverAddress, serverPort) <- outputLocs) { + splitsByUri += ((serverAddress, serverPort, inputId)) + } + + // TODO: Could broadcast outputLocs + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = outputLocs.size + hasSplits = 0 + + totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) + hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) + + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + var numThreadsToCreate = + Math.min(totalSplits, Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex) + val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId) + + threadPool.execute(new ShuffleClient(serverAddress, serverPort, + shuffleId.toInt, inputId, myId, splitIndex)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(CustomBlockedLocalFileShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(hostAddress: String, listenPort: Int, shuffleId: Int, + inputId: Int, myId: Int, splitIndex: Int) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + private var receptionSucceeded = false + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) + + try { + // Everything will break if BLOCKNUM is not correctly received + // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown + peerSocketToSource = new Socket(hostAddress, listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream(isSource) + + // Send path information + oosSource.writeObject((shuffleId, inputId, myId)) + + // TODO: Can be optimized. No need to do it everytime. + // Receive BLOCKNUM + totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int] + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) { + // Set receptionSucceeded to false before trying for each block + receptionSucceeded = false + + // Request specific block + oosSource.writeObject(hasBlocksInSplit(splitIndex)) + + // Good to go. First, receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo("Received requestedFileLen = " + requestedFileLen) + + val requestPath = "%d/%d/%d-%d".format(shuffleId, inputId, myId, + hasBlocksInSplit(splitIndex)) + + // Receive the file + if (requestedFileLen != -1) { + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath)) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 + + // Split has been received only if all the blocks have been received + if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + } + + // Consistent state in accounting variables + receptionSucceeded = true + + logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath)) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath) + " took " + readTime + " millis.") + } else { + throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath) + } + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + cleanUpConnections() + } + } + + private def cleanUpConnections(): Unit = { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + } + } +} + +object CustomBlockedLocalFileShuffle extends Logging { + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int, + blockId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "%d-%d".format(outputId, blockId)) + return file + } + + def getBlockNumOutputFile(shuffleId: Long, inputId: Int, + outputId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "BLOCKNUM-" + outputId) + return file + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread(r) + t.setDaemon(true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } + + class ShuffleServer + extends Thread with Logging { + var threadPool = + newDaemonFixedThreadPool(Shuffle.MaxTxConnections) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket(0) + serverPort = serverSocket.getLocalPort + + logInfo("ShuffleServer started with " + serverSocket) + logInfo("Local URI: http://" + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ShuffleServerThread(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ShuffleServer now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ShuffleServerThread(val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush() + private val bos = new BufferedOutputStream(os) + bos.flush() + private val oos = new ObjectOutputStream(os) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive basic path information + val (shuffleId, myIndex, outputId) = + ois.readObject.asInstanceOf[(Int, Int, Int)] + + var requestPathBase = "%d/%d/%d".format(shuffleId, myIndex, outputId) + + logInfo("requestPathBase: " + requestPathBase) + + // Read BLOCKNUM file and send back the total number of blocks + val blockNumFilePath = "%s/%d/%d/BLOCKNUM-%d".format(shuffleDir, + shuffleId, myIndex, outputId) + val blockNumIn = + new ObjectInputStream(new FileInputStream(blockNumFilePath)) + val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int] + blockNumIn.close() + + oos.writeObject(BLOCKNUM) + + val startTime = System.currentTimeMillis + var curTime = startTime + var keepSending = true + var numBlocksToSend = Shuffle.MaxChatBlocks + + while (keepSending && numBlocksToSend > 0) { + // Receive specific block request + val blockId = ois.readObject.asInstanceOf[Int] + + // Ready to send + val requestPath = requestPathBase + "-" + blockId + + // Open the file + var requestedFile: File = null + var requestedFileLen = -1 + try { + requestedFile = new File(shuffleDir + "/" + requestPath) + requestedFileLen = requestedFile.length.toInt + } catch { + case e: Exception => { } + } + + // Send the length of the requestPath to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + oos.writeObject(requestedFileLen) + oos.flush() + + logInfo("requestedFileLen = " + requestedFileLen) + + // Read and send the requested file + if (requestedFileLen != -1) { + // Read + var byteArray = new Array[Byte](requestedFileLen) + val bis = + new BufferedInputStream(new FileInputStream(requestedFile)) + + var bytesRead = bis.read(byteArray, 0, byteArray.length) + var alreadyRead = bytesRead + + while (alreadyRead < requestedFileLen) { + bytesRead = bis.read(byteArray, alreadyRead, + (byteArray.length - alreadyRead)) + if(bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + bis.close() + + // Send + bos.write(byteArray, 0, byteArray.length) + bos.flush() + + // Update loop variables + numBlocksToSend = numBlocksToSend - 1 + + curTime = System.currentTimeMillis + // Revoke sending only if there is anyone waiting in the queue + if (curTime - startTime >= Shuffle.MaxChatTime && + threadPool.getQueue.size > 0) { + keepSending = false + } + } else { + // Close the connection + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc + // then close everything up + // Exception can happen if the receiver stops receiving + // EOFException is expected to happen because receiver can break + // connection as soon as it has all the blocks + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleServerThread had a " + e) + } + } finally { + logInfo("ShuffleServerThread is closing streams and sockets") + ois.close() + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close() + bos.close() + clientSocket.close() + } + } + } + } +} diff --git a/core/src/main/scala/spark/CustomParallelFakeShuffle.scala b/core/src/main/scala/spark/CustomParallelFakeShuffle.scala new file mode 100644 index 0000000000..889c5111b6 --- /dev/null +++ b/core/src/main/scala/spark/CustomParallelFakeShuffle.scala @@ -0,0 +1,495 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW). + * + * An implementation of shuffle using fake data served through custom server + * where receivers create simultaneous connections to multiple servers by + * setting the 'spark.shuffle.maxRxConnections' config option. + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class CustomParallelFakeShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = CustomParallelFakeShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + val splitName = + CustomParallelFakeShuffle.getSplitName(shuffleId, myIndex, i) + + val writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + splitName) + + // Write buckets(i) to a byte array & put in splitsCache instead of file + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + buckets(i).foreach(pair => oos.writeObject(pair)) + oos.close + baos.close + + // Store the length only + CustomParallelFakeShuffle.splitsCache(splitName) = baos.toByteArray.length + val splitLen = baos.toByteArray.length + + logInfo("END WRITE: " + splitName) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.") + } + + (myIndex, CustomParallelFakeShuffle.serverAddress, + CustomParallelFakeShuffle.serverPort) + }).collect() + + val splitsByUri = new ArrayBuffer[(String, Int, Int)] + for ((inputId, serverAddress, serverPort) <- outputLocs) { + splitsByUri += ((serverAddress, serverPort, inputId)) + } + + // TODO: Could broadcast splitsByUri + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = splitsByUri.size + hasSplits = 0 + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + var totalThreadsCreated = 0 + + while (hasSplits < totalSplits) { + var numThreadsToCreate = Math.min(totalSplits, + Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex) + val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId) + + threadPool.execute(new ShuffleClient(splitIndex, serverAddress, + serverPort, requestSplit)) + totalThreadsCreated += 1 + logInfo("totalThreadsCreated = %d / threadPool.getActiveCount = %d".format(totalThreadsCreated, threadPool.getActiveCount.toInt)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + combiners + }) + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(CustomParallelFakeShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int, + requestSplit: String) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + private var receptionSucceeded = false + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) + + logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit)) + + try { + // Connect to the source + peerSocketToSource = new Socket(hostAddress, listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream(isSource) + + // Send the request + oosSource.writeObject(requestSplit) + + // Receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo("Received requestedFileLen = " + requestedFileLen) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + // Receive the file + if (requestedFileLen != -1) { + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) + +// var recvByteArray = new Array[Byte](requestedFileLen) +// var alreadyRead = 0 +// var bytesRead = 0 +// +// while (alreadyRead != requestedFileLen) { +// bytesRead = isSource.read(recvByteArray, alreadyRead, +// requestedFileLen - alreadyRead) +// if (bytesRead > 0) { +// alreadyRead = alreadyRead + bytesRead +// } +// } + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](65536) + var alreadyRead = 0 + var bytesRead = 0 + + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // TODO: Updating stats before consumption is completed + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + + // We have received splitIndex + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + + receptionSucceeded = true + + logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.") + } else { + throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit) + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + // If reception failed, unset for future retry + if (!receptionSucceeded) { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + } + cleanUpConnections() + } + } + + private def cleanUpConnections(): Unit = { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + } + } +} + +object CustomParallelFakeShuffle extends Logging { + // Cache for keeping the splits around + val splitsCache = new HashMap[String, Int] + + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = { + initializeIfNeeded() + // Adding shuffleDir is unnecessary. Added to keep the parsers working + return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId) + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread(r) + t.setDaemon(true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } + + class ShuffleServer + extends Thread with Logging { + var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket(0) + serverPort = serverSocket.getLocalPort + + logInfo("ShuffleServer started with " + serverSocket) + logInfo("Local URI: http://" + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ShuffleServerThread(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ShuffleServer now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ShuffleServerThread(val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush() + private val bos = new BufferedOutputStream(os) + bos.flush() + private val oos = new ObjectOutputStream(os) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive requestedSplit from the receiver + // Adding shuffleDir is unnecessary. Added to keep the parsers working + var requestedSplit = + shuffleDir + "/" + ois.readObject.asInstanceOf[String] + logInfo("requestedSplit: " + requestedSplit) + + // Send the length of the requestedSplit to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + var requestedSplitLen = -1 + + try { + requestedSplitLen = + CustomParallelFakeShuffle.splitsCache(requestedSplit) + } catch { + case e: Exception => { } + } + + oos.writeObject(requestedSplitLen) + oos.flush() + + logInfo("requestedSplitLen = " + requestedSplitLen) + + // Read and send the requested split + if (requestedSplitLen != -1) { + // Send fake data +// var byteArray = new Array[Byte](requestedSplitLen) +// bos.write(byteArray, 0, byteArray.length) +// bos.flush() + + val buf = new Array[Byte](65536) + var bytesSent = 0 + while (bytesSent < requestedSplitLen) { + val bytesToSend = Math.min(requestedSplitLen - bytesSent, buf.length) + bos.write(buf, 0, bytesToSend) + bos.flush() + bytesSent += bytesToSend + } + } else { + // Close the connection + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc + // then close everything up + // Exception can happen if the receiver stops receiving + case e: Exception => { + logInfo("ShuffleServerThread had a " + e) + } + } finally { + logInfo("ShuffleServerThread is closing streams and sockets") + ois.close() + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close() + bos.close() + clientSocket.close() + } + } + } + } +} diff --git a/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala b/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala new file mode 100644 index 0000000000..48f3685a1a --- /dev/null +++ b/core/src/main/scala/spark/CustomParallelInMemoryShuffle.scala @@ -0,0 +1,535 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATON (FOR NOW). + * + * An implementation of shuffle using local memory served through custom server + * where receivers create simultaneous connections to multiple servers by + * setting the 'spark.shuffle.maxRxConnections' config option. + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class CustomParallelInMemoryShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = CustomParallelInMemoryShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + val splitName = + CustomParallelInMemoryShuffle.getSplitName(shuffleId, myIndex, i) + + val writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + splitName) + + // Write buckets(i) to a byte array & put in splitsCache instead of file + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + buckets(i).foreach(pair => oos.writeObject(pair)) + oos.close + baos.close + + CustomParallelInMemoryShuffle.splitsCache(splitName) = baos.toByteArray + val splitLen = + CustomParallelInMemoryShuffle.splitsCache(splitName).length + + logInfo("END WRITE: " + splitName) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.") + } + + (myIndex, CustomParallelInMemoryShuffle.serverAddress, + CustomParallelInMemoryShuffle.serverPort) + }).collect() + + val splitsByUri = new ArrayBuffer[(String, Int, Int)] + for ((inputId, serverAddress, serverPort) <- outputLocs) { + splitsByUri += ((serverAddress, serverPort, inputId)) + } + + // TODO: Could broadcast splitsByUri + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = splitsByUri.size + hasSplits = 0 + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + var numThreadsToCreate = Math.min(totalSplits, + Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex) + val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId) + + threadPool.execute(new ShuffleClient(splitIndex, serverAddress, + serverPort, requestSplit)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(CustomParallelInMemoryShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int, + requestSplit: String) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + private var receptionSucceeded = false + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) + + logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit)) + + try { + // Connect to the source + peerSocketToSource = new Socket(hostAddress, listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream(isSource) + + // Send the request + oosSource.writeObject(requestSplit) + + // Receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo("Received requestedFileLen = " + requestedFileLen) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + // Receive the file + if (requestedFileLen != -1) { + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + + // We have received splitIndex + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + + receptionSucceeded = true + + logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit)) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.") + } else { + throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit) + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + // If reception failed, unset for future retry + if (!receptionSucceeded) { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + } + cleanUpConnections() + } + } + + private def cleanUpConnections(): Unit = { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + } + } +} + +object CustomParallelInMemoryShuffle extends Logging { + // Cache for keeping the splits around + val splitsCache = new HashMap[String, Array[Byte]] + + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = { + initializeIfNeeded() + // Adding shuffleDir is unnecessary. Added to keep the parsers working + return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId) + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread(r) + t.setDaemon(true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } + + class ShuffleServer + extends Thread with Logging { + var threadPool = newDaemonFixedThreadPool(Shuffle.MaxTxConnections) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket(0) + serverPort = serverSocket.getLocalPort + + logInfo("ShuffleServer started with " + serverSocket) + logInfo("Local URI: http://" + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ShuffleServerThread(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ShuffleServer now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ShuffleServerThread(val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush() + private val bos = new BufferedOutputStream(os) + bos.flush() + private val oos = new ObjectOutputStream(os) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive requestedSplit from the receiver + // Adding shuffleDir is unnecessary. Added to keep the parsers working + var requestedSplit = + shuffleDir + "/" + ois.readObject.asInstanceOf[String] + logInfo("requestedSplit: " + requestedSplit) + + // Send the length of the requestedSplit to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + var requestedSplitLen = -1 + + try { + requestedSplitLen = + CustomParallelInMemoryShuffle.splitsCache(requestedSplit).length + } catch { + case e: Exception => { } + } + + oos.writeObject(requestedSplitLen) + oos.flush() + + logInfo("requestedSplitLen = " + requestedSplitLen) + + // Read and send the requested split + if (requestedSplitLen != -1) { + // Send + bos.write(CustomParallelInMemoryShuffle.splitsCache(requestedSplit), + 0, requestedSplitLen) + bos.flush() + } else { + // Close the connection + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc + // then close everything up + // Exception can happen if the receiver stops receiving + case e: Exception => { + logInfo("ShuffleServerThread had a " + e) + } + } finally { + logInfo("ShuffleServerThread is closing streams and sockets") + ois.close() + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close() + bos.close() + clientSocket.close() + } + } + } + } +} diff --git a/core/src/main/scala/spark/CustomParallelLocalFileShuffle.scala b/core/src/main/scala/spark/CustomParallelLocalFileShuffle.scala new file mode 100644 index 0000000000..87e824fb2e --- /dev/null +++ b/core/src/main/scala/spark/CustomParallelLocalFileShuffle.scala @@ -0,0 +1,541 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * An implementation of shuffle using local files served through custom server + * where receivers create simultaneous connections to multiple servers by + * setting the 'spark.shuffle.maxRxConnections' config option. + * + * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, + * maxTxConnections >= maxRxConnections * numReducersPerMachine + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class CustomParallelLocalFileShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = CustomParallelLocalFileShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + val file = CustomParallelLocalFileShuffle.getOutputFile(shuffleId, + myIndex, i) + val writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) + val out = new ObjectOutputStream(new FileOutputStream(file)) + buckets(i).foreach(pair => out.writeObject(pair)) + out.close() + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + } + + (myIndex, CustomParallelLocalFileShuffle.serverAddress, + CustomParallelLocalFileShuffle.serverPort) + }).collect() + + val splitsByUri = new ArrayBuffer[(String, Int, Int)] + for ((inputId, serverAddress, serverPort) <- outputLocs) { + splitsByUri += ((serverAddress, serverPort, inputId)) + } + + // TODO: Could broadcast splitsByUri + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = splitsByUri.size + hasSplits = 0 + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + var numThreadsToCreate = Math.min(totalSplits, + Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex) + val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId) + + threadPool.execute(new ShuffleClient(splitIndex, serverAddress, + serverPort, requestPath)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(CustomParallelLocalFileShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int, + requestPath: String) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + private var receptionSucceeded = false + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) + + logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath)) + + try { + // Connect to the source + peerSocketToSource = new Socket(hostAddress, listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream(isSource) + + // Send the request + oosSource.writeObject(requestPath) + + // Receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo("Received requestedFileLen = " + requestedFileLen) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + // Receive the file + if (requestedFileLen != -1) { + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath)) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + + // We have received splitIndex + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + + receptionSucceeded = true + + logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath)) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath) + " took " + readTime + " millis.") + } else { + throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath) + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + // If reception failed, unset for future retry + if (!receptionSucceeded) { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + } + cleanUpConnections() + } + } + + private def cleanUpConnections(): Unit = { + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + } + } +} + +object CustomParallelLocalFileShuffle extends Logging { + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "" + outputId) + return file + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread(r) + t.setDaemon(true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } + + class ShuffleServer + extends Thread with Logging { + var threadPool = + newDaemonFixedThreadPool(Shuffle.MaxTxConnections) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket(0) + serverPort = serverSocket.getLocalPort + + logInfo("ShuffleServer started with " + serverSocket) + logInfo("Local URI: http://" + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ShuffleServerThread(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ShuffleServer now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ShuffleServerThread(val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush() + private val bos = new BufferedOutputStream(os) + bos.flush() + private val oos = new ObjectOutputStream(os) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive requestPath from the receiver + var requestPath = ois.readObject.asInstanceOf[String] + logInfo("requestPath: " + shuffleDir + "/" + requestPath) + + // Open the file + var requestedFile: File = null + var requestedFileLen = -1 + try { + requestedFile = new File(shuffleDir + "/" + requestPath) + requestedFileLen = requestedFile.length.toInt + } catch { + case e: Exception => { } + } + + // Send the length of the requestPath to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + oos.writeObject(requestedFileLen) + oos.flush() + + logInfo("requestedFileLen = " + requestedFileLen) + + // Read and send the requested file + if (requestedFileLen != -1) { + // Read + var byteArray = new Array[Byte](requestedFileLen) + val bis = + new BufferedInputStream(new FileInputStream(requestedFile)) + + var bytesRead = bis.read(byteArray, 0, byteArray.length) + var alreadyRead = bytesRead + + while (alreadyRead < requestedFileLen) { + bytesRead = bis.read(byteArray, alreadyRead, + (byteArray.length - alreadyRead)) + if(bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + bis.close() + + // Send + bos.write(byteArray, 0, byteArray.length) + bos.flush() + } else { + // Close the connection + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc + // then close everything up + // Exception can happen if the receiver stops receiving + case e: Exception => { + logInfo("ShuffleServerThread had a " + e) + } + } finally { + logInfo("ShuffleServerThread is closing streams and sockets") + ois.close() + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close() + bos.close() + clientSocket.close() + } + } + } + } +} diff --git a/core/src/main/scala/spark/DfsShuffle.scala b/core/src/main/scala/spark/DfsShuffle.scala index 7a42bf2d06..bf91be7d2c 100644 --- a/core/src/main/scala/spark/DfsShuffle.scala +++ b/core/src/main/scala/spark/DfsShuffle.scala @@ -9,7 +9,6 @@ import scala.collection.mutable.HashMap import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} - /** * A simple implementation of shuffle using a distributed file system. * @@ -82,7 +81,6 @@ class DfsShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { } } - /** * Companion object of DfsShuffle; responsible for initializing a Hadoop * FileSystem object based on the spark.dfs property and generating names diff --git a/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala new file mode 100644 index 0000000000..8e89cadfdd --- /dev/null +++ b/core/src/main/scala/spark/HttpBlockedLocalFileShuffle.scala @@ -0,0 +1,471 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * An implementation of shuffle using local files served through HTTP where + * receivers create simultaneous connections to multiple servers by setting the + * 'spark.shuffle.maxRxConnections' config option. + * + * By controlling the 'spark.shuffle.blockSize' config option one can also + * control the largest block size to retrieve by each reducers. An INDEX file + * keeps track of block boundaries instead of creating many smaller files. + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class HttpBlockedLocalFileShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + + @transient var blocksInSplit: Array[ArrayBuffer[Long]] = null + @transient var totalBlocksInSplit: Array[Int] = null + @transient var hasBlocksInSplit: Array[Int] = null + + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = HttpBlockedLocalFileShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + // Open the INDEX file + var indexFile: File = + HttpBlockedLocalFileShuffle.getBlockIndexOutputFile(shuffleId, + myIndex, i) + var indexOut = new ObjectOutputStream(new FileOutputStream(indexFile)) + var indexDirty: Boolean = true + var alreadyWritten: Long = 0 + + // Open the actual file + var file: File = + HttpBlockedLocalFileShuffle.getOutputFile(shuffleId, myIndex, i) + val out = new ObjectOutputStream(new FileOutputStream(file)) + + val writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) + + buckets(i).foreach(pair => { + out.writeObject(pair) + out.flush() + indexDirty = true + + // Update the INDEX file if more than blockSize limit has been written + if (file.length - alreadyWritten > Shuffle.BlockSize) { + indexOut.writeObject(file.length) + indexDirty = false + alreadyWritten = file.length + } + }) + + // Write down the last range if it was not written + if (indexDirty) { + indexOut.writeObject(file.length) + } + + out.close() + indexOut.close() + + logInfo("END WRITE: " + file) + val writeTime = (System.currentTimeMillis - writeStartTime) + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + } + + (myIndex, HttpBlockedLocalFileShuffle.serverUri) + }).collect() + + // TODO: Could broadcast outputLocs + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = outputLocs.size + hasSplits = 0 + + blocksInSplit = Array.tabulate(totalSplits)(_ => new ArrayBuffer[Long]) + totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) + hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) + + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + var numThreadsToCreate = + Math.min(totalSplits, Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (inputId, serverUri) = outputLocs(splitIndex) + + threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt, + inputId, myId, splitIndex)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(HttpBlockedLocalFileShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(serverUri: String, shuffleId: Int, + inputId: Int, myId: Int, splitIndex: Int) + extends Thread with Logging { + private var receptionSucceeded = false + + override def run: Unit = { + try { + // First get the INDEX file if totalBlocksInSplit(splitIndex) is unknown + if (totalBlocksInSplit(splitIndex) == -1) { + val url = "%s/shuffle/%d/%d/INDEX-%d".format(serverUri, shuffleId, + inputId, myId) + val inputStream = new ObjectInputStream(new URL(url).openStream()) + + try { + while (true) { + blocksInSplit(splitIndex) += + inputStream.readObject().asInstanceOf[Long] + } + } catch { + case e: EOFException => {} + } + + totalBlocksInSplit(splitIndex) = blocksInSplit(splitIndex).size + inputStream.close() + } + + // Open connection + val urlString = + "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, myId) + val url = new URL(urlString) + val httpConnection = + url.openConnection().asInstanceOf[HttpURLConnection] + + // Set the range to download + val blockStartsAt = hasBlocksInSplit(splitIndex) match { + case 0 => 0 + case _ => blocksInSplit(splitIndex)(hasBlocksInSplit(splitIndex) - 1) + 1 + } + val blockEndsAt = blocksInSplit(splitIndex)(hasBlocksInSplit(splitIndex)) + httpConnection.setRequestProperty("Range", + "bytes=" + blockStartsAt + "-" + blockEndsAt) + + // Connect to the server + httpConnection.connect() + + val urStringWithRange = + urlString + "[%d:%d]".format(blockStartsAt, blockEndsAt) + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: " + urStringWithRange) + + // Receive data in an Array[Byte] + val requestedFileLen: Int = (blockEndsAt - blockStartsAt).toInt + 1 + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + val isSource = httpConnection.getInputStream() + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Disconnect + httpConnection.disconnect() + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 + + // Split has been received only if all the blocks have been received + if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + } + + // We have received splitIndex + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + + receptionSucceeded = true + + logInfo("END READ: " + urStringWithRange) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading " + urStringWithRange + " took " + readTime + " millis.") + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + // If reception failed, unset for future retry + if (!receptionSucceeded) { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + } + } + } + } +} + +object HttpBlockedLocalFileShuffle extends Logging { + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + private var server: HttpServer = null + private var serverUri: String = null + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + val extServerPort = System.getProperty( + "spark.localFileShuffle.external.server.port", "-1").toInt + if (extServerPort != -1) { + // We're using an external HTTP server; set URI relative to its root + var extServerPath = System.getProperty( + "spark.localFileShuffle.external.server.path", "") + if (extServerPath != "" && !extServerPath.endsWith("/")) { + extServerPath += "/" + } + serverUri = "http://%s:%d/%s/spark-local-%s".format( + Utils.localIpAddress, extServerPort, extServerPath, localDirUuid) + } else { + // Create our own server + server = new HttpServer(localDir) + server.start() + serverUri = server.uri + } + initialized = true + logInfo("Local URI: " + serverUri) + } + } + + def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "" + outputId) + return file + } + + def getBlockIndexOutputFile(shuffleId: Long, inputId: Int, + outputId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "INDEX-" + outputId) + return file + } + + def getServerUri(): String = { + initializeIfNeeded() + serverUri + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread(r) + t.setDaemon(true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } +} diff --git a/core/src/main/scala/spark/HttpParallelLocalFileShuffle.scala b/core/src/main/scala/spark/HttpParallelLocalFileShuffle.scala new file mode 100644 index 0000000000..8e7b897668 --- /dev/null +++ b/core/src/main/scala/spark/HttpParallelLocalFileShuffle.scala @@ -0,0 +1,391 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * An implementation of shuffle using local files served through HTTP where + * receivers create simultaneous connections to multiple servers by setting the + * 'spark.shuffle.maxRxConnections' config option. + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class HttpParallelLocalFileShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = HttpParallelLocalFileShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + val file = + HttpParallelLocalFileShuffle.getOutputFile(shuffleId, myIndex, i) + + val writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) + + val out = new ObjectOutputStream(new FileOutputStream(file)) + buckets(i).foreach(pair => out.writeObject(pair)) + + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + + out.close() + } + + (myIndex, HttpParallelLocalFileShuffle.serverUri) + }).collect() + + // TODO: Could broadcast outputLocs + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = outputLocs.size + hasSplits = 0 + + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + var numThreadsToCreate = + Math.min(totalSplits, Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (inputId, serverUri) = outputLocs(splitIndex) + + threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt, + inputId, myId, splitIndex)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(HttpParallelLocalFileShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(serverUri: String, shuffleId: Int, + inputId: Int, myId: Int, splitIndex: Int) + extends Thread with Logging { + private var receptionSucceeded = false + + override def run: Unit = { + try { + // Open connection + val urlString = + "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, myId) + val url = new URL(urlString) + val httpConnection = + url.openConnection().asInstanceOf[HttpURLConnection] + + // Connect to the server + httpConnection.connect() + + // Receive file length + var requestedFileLen = httpConnection.getContentLength + + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: " + urlString) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + val isSource = httpConnection.getInputStream() + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Disconnect + httpConnection.disconnect() + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + + // We have received splitIndex + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + + receptionSucceeded = true + + logInfo("END READ: " + urlString) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading " + urlString + " took " + readTime + " millis.") + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + // If reception failed, unset for future retry + if (!receptionSucceeded) { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + } + } + } + } +} + +object HttpParallelLocalFileShuffle extends Logging { + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + private var server: HttpServer = null + private var serverUri: String = null + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + val extServerPort = System.getProperty( + "spark.localFileShuffle.external.server.port", "-1").toInt + if (extServerPort != -1) { + // We're using an external HTTP server; set URI relative to its root + var extServerPath = System.getProperty( + "spark.localFileShuffle.external.server.path", "") + if (extServerPath != "" && !extServerPath.endsWith("/")) { + extServerPath += "/" + } + serverUri = "http://%s:%d/%s/spark-local-%s".format( + Utils.localIpAddress, extServerPort, extServerPath, localDirUuid) + } else { + // Create our own server + server = new HttpServer(localDir) + server.start() + serverUri = server.uri + } + initialized = true + logInfo("Local URI: " + serverUri) + } + } + + def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "" + outputId) + return file + } + + def getServerUri(): String = { + initializeIfNeeded() + serverUri + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread(r) + t.setDaemon(true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } +} diff --git a/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala new file mode 100644 index 0000000000..cfd84fdb83 --- /dev/null +++ b/core/src/main/scala/spark/ManualBlockedLocalFileShuffle.scala @@ -0,0 +1,465 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * An implementation of shuffle using local files served through HTTP where + * receivers create simultaneous connections to multiple servers by setting the + * 'spark.shuffle.maxRxConnections' config option. + * + * By controlling the 'spark.shuffle.blockSize' config option one can also + * control the largest block size to divide each map output into. Essentially, + * instead of creating one large output file for each reducer, maps create + * multiple smaller files to enable finer level of engagement. + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class ManualBlockedLocalFileShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + + @transient var totalBlocksInSplit: Array[Int] = null + @transient var hasBlocksInSplit: Array[Int] = null + + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = ManualBlockedLocalFileShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + var blockNum = 0 + var isDirty = false + var file: File = null + var out: ObjectOutputStream = null + + var writeStartTime: Long = 0 + + buckets(i).foreach(pair => { + // Open a new file if necessary + if (!isDirty) { + file = ManualBlockedLocalFileShuffle.getOutputFile(shuffleId, + myIndex, i, blockNum) + writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) + + out = new ObjectOutputStream(new FileOutputStream(file)) + } + + out.writeObject(pair) + out.flush() + isDirty = true + + // Close the old file if has crossed the blockSize limit + if (file.length > Shuffle.BlockSize) { + out.close() + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + isDirty = false + } + }) + + if (isDirty) { + out.close() + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + } + + // Write the BLOCKNUM file + file = ManualBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId, + myIndex, i) + out = new ObjectOutputStream(new FileOutputStream(file)) + out.writeObject(blockNum) + out.close() + } + + (myIndex, ManualBlockedLocalFileShuffle.serverUri) + }).collect() + + // TODO: Could broadcast outputLocs + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = outputLocs.size + hasSplits = 0 + + totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) + hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) + + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + var numThreadsToCreate = + Math.min(totalSplits, Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (inputId, serverUri) = outputLocs(splitIndex) + + threadPool.execute(new ShuffleClient(serverUri, shuffleId.toInt, + inputId, myId, splitIndex, mergeCombiners)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(ManualBlockedLocalFileShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(serverUri: String, shuffleId: Int, + inputId: Int, myId: Int, splitIndex: Int, + mergeCombiners: (C, C) => C) + extends Thread with Logging { + private var receptionSucceeded = false + + override def run: Unit = { + try { + // Everything will break if BLOCKNUM is not correctly received + // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown + if (totalBlocksInSplit(splitIndex) == -1) { + val url = "%s/shuffle/%d/%d/BLOCKNUM-%d".format(serverUri, shuffleId, + inputId, myId) + val inputStream = new ObjectInputStream(new URL(url).openStream()) + totalBlocksInSplit(splitIndex) = + inputStream.readObject().asInstanceOf[Int] + inputStream.close() + } + + // Open connection + val urlString = + "%s/shuffle/%d/%d/%d-%d".format(serverUri, shuffleId, inputId, + myId, hasBlocksInSplit(splitIndex)) + val url = new URL(urlString) + val httpConnection = + url.openConnection().asInstanceOf[HttpURLConnection] + + // Connect to the server + httpConnection.connect() + + // Receive file length + var requestedFileLen = httpConnection.getContentLength + + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: " + url) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + val isSource = httpConnection.getInputStream() + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Disconnect + httpConnection.disconnect() + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 + + // Split has been received only if all the blocks have been received + if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + } + + // We have received splitIndex + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + + receptionSucceeded = true + + logInfo("END READ: " + url) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading " + url + " took " + readTime + " millis.") + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + // If reception failed, unset for future retry + if (!receptionSucceeded) { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + } + } + } + } +} + +object ManualBlockedLocalFileShuffle extends Logging { + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + private var server: HttpServer = null + private var serverUri: String = null + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + val extServerPort = System.getProperty( + "spark.localFileShuffle.external.server.port", "-1").toInt + if (extServerPort != -1) { + // We're using an external HTTP server; set URI relative to its root + var extServerPath = System.getProperty( + "spark.localFileShuffle.external.server.path", "") + if (extServerPath != "" && !extServerPath.endsWith("/")) { + extServerPath += "/" + } + serverUri = "http://%s:%d/%s/spark-local-%s".format( + Utils.localIpAddress, extServerPort, extServerPath, localDirUuid) + } else { + // Create our own server + server = new HttpServer(localDir) + server.start() + serverUri = server.uri + } + initialized = true + logInfo("Local URI: " + serverUri) + } + } + + def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int, + blockId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "%d-%d".format(outputId, blockId)) + return file + } + + def getBlockNumOutputFile(shuffleId: Long, inputId: Int, + outputId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "BLOCKNUM-" + outputId) + return file + } + + def getServerUri(): String = { + initializeIfNeeded() + serverUri + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread(r) + t.setDaemon(true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } +} diff --git a/core/src/main/scala/spark/Shuffle.scala b/core/src/main/scala/spark/Shuffle.scala index 4c5649b537..f2d790f727 100644 --- a/core/src/main/scala/spark/Shuffle.scala +++ b/core/src/main/scala/spark/Shuffle.scala @@ -1,5 +1,9 @@ package spark +import java.net._ +import java.util.{BitSet} +import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} + /** * A trait for shuffle system. Given an input RDD and combiner functions * for PairRDDExtras.combineByKey(), returns an output RDD. @@ -13,3 +17,112 @@ trait Shuffle[K, V, C] { mergeCombiners: (C, C) => C) : RDD[(K, C)] } + +/** + * An object containing common shuffle config parameters + */ +private object Shuffle +extends Logging { + // Tracker communication constants + val ReducerEntering = 0 + val ReducerLeaving = 1 + + // ShuffleTracker info + private var MasterHostAddress_ = System.getProperty( + "spark.shuffle.masterHostAddress", InetAddress.getLocalHost.getHostAddress) + private var MasterTrackerPort_ = System.getProperty( + "spark.shuffle.masterTrackerPort", "22222").toInt + + private var BlockSize_ = System.getProperty( + "spark.shuffle.blockSize", "1024").toInt * 1024 + + // Used thoughout the code for small and large waits/timeouts + private var MinKnockInterval_ = System.getProperty( + "spark.shuffle.minKnockInterval", "1000").toInt + private var MaxKnockInterval_ = System.getProperty( + "spark.shuffle.maxKnockInterval", "5000").toInt + + // Maximum number of connections + private var MaxRxConnections_ = System.getProperty( + "spark.shuffle.maxRxConnections", "4").toInt + private var MaxTxConnections_ = System.getProperty( + "spark.shuffle.maxTxConnections", "8").toInt + + // Upper limit on receiving in blocked implementations (whichever comes first) + private var MaxChatTime_ = System.getProperty( + "spark.shuffle.maxChatTime", "250").toInt + private var MaxChatBlocks_ = System.getProperty( + "spark.shuffle.maxChatBlocks", "1024").toInt + + // A reducer is throttled if it is this much faster + private var ThrottleFraction_ = System.getProperty( + "spark.shuffle.throttleFraction", "2.0").toDouble + + def MasterHostAddress = MasterHostAddress_ + def MasterTrackerPort = MasterTrackerPort_ + + def BlockSize = BlockSize_ + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + def MaxRxConnections = MaxRxConnections_ + def MaxTxConnections = MaxTxConnections_ + + def MaxChatTime = MaxChatTime_ + def MaxChatBlocks = MaxChatBlocks_ + + def ThrottleFraction = ThrottleFraction_ + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread(r) + t.setDaemon(true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } + + // Wrapper over newCachedThreadPool + def newDaemonCachedThreadPool: ThreadPoolExecutor = { + var threadPool = + Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory(newDaemonThreadFactory) + + return threadPool + } +} + +@serializable +case class SplitInfo(val hostAddress: String, val listenPort: Int, + val splitId: Int) { + + var hasSplits = 0 + var hasSplitsBitVector: BitSet = null + + // Used by mappers of dim |numOutputSplits| + var totalBlocksPerOutputSplit: Array[Int] = null + // Used by reducers of dim |numInputSplits| + var hasBlocksPerInputSplit: Array[Int] = null +} + +object SplitInfo { + // Constants for special values of listenPort + val MappersBusy = -1 + + // Other constants + val UnusedParam = 0 +} diff --git a/core/src/main/scala/spark/ShuffleTrackerStrategy.scala b/core/src/main/scala/spark/ShuffleTrackerStrategy.scala new file mode 100644 index 0000000000..fc2f4aa5f7 --- /dev/null +++ b/core/src/main/scala/spark/ShuffleTrackerStrategy.scala @@ -0,0 +1,470 @@ +package spark + +import java.util.{BitSet, Random} + +import scala.collection.mutable.ArrayBuffer +import scala.util.Sorting._ + +/** + * A trait for implementing tracker strategies for the shuffle system. + */ +trait ShuffleTrackerStrategy { + // Initialize + def initialize(outputLocs_ : Array[SplitInfo]): Unit + + // Select a set of splits and send back + def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] + + // Update internal stats if things could be sent back successfully + def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit + + // A reducer is done. Update internal stats + def deleteReducerFrom(reducerSplitInfo: SplitInfo, + receptionStat: ReceptionStats): Unit +} + +/** + * Helper class to send back reception stats from the reducer + */ +case class ReceptionStats(val bytesReceived: Int, val timeSpent: Int, + serverSplitIndex: Int) { } + +/** + * A simple ShuffleTrackerStrategy that tries to balance the total number of + * connections created for each mapper. + */ +class BalanceConnectionsShuffleTrackerStrategy +extends ShuffleTrackerStrategy with Logging { + private var numSources = -1 + private var outputLocs: Array[SplitInfo] = null + private var curConnectionsPerLoc: Array[Int] = null + private var totalConnectionsPerLoc: Array[Int] = null + + // The order of elements in the outputLocs (splitIndex) is used to pass + // information back and forth between the tracker, mappers, and reducers + def initialize(outputLocs_ : Array[SplitInfo]): Unit = { + outputLocs = outputLocs_ + numSources = outputLocs.size + + // Now initialize other data structures + curConnectionsPerLoc = Array.tabulate(numSources)(_ => 0) + totalConnectionsPerLoc = Array.tabulate(numSources)(_ => 0) + } + + def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized { + var minConnections = Int.MaxValue + var minIndex = -1 + + var splitIndices = ArrayBuffer[Int]() + + for (i <- 0 until numSources) { + // TODO: Use of MaxRxConnections instead of MaxTxConnections is + // intentional here. MaxTxConnections is per machine whereas + // MaxRxConnections is per mapper/reducer. Will have to find a better way. + if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections && + totalConnectionsPerLoc(i) < minConnections && + !reducerSplitInfo.hasSplitsBitVector.get(i)) { + minConnections = totalConnectionsPerLoc(i) + minIndex = i + } + } + + if (minIndex != -1) { + splitIndices += minIndex + } + + return splitIndices + } + + def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized { + splitIndices.foreach { splitIndex => + curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1 + totalConnectionsPerLoc(splitIndex) = + totalConnectionsPerLoc(splitIndex) + 1 + } + } + + def deleteReducerFrom(reducerSplitInfo: SplitInfo, + receptionStat: ReceptionStats): Unit = synchronized { + // Decrease number of active connections + curConnectionsPerLoc(receptionStat.serverSplitIndex) = + curConnectionsPerLoc(receptionStat.serverSplitIndex) - 1 + + // TODO: This assertion can legally fail when ShuffleClient times out while + // waiting for tracker response and decides to go to a random server + // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) + + // Just in case + if (curConnectionsPerLoc(receptionStat.serverSplitIndex) < 0) { + curConnectionsPerLoc(receptionStat.serverSplitIndex) = 0 + } + } +} + +/** + * A simple ShuffleTrackerStrategy that randomly selects mapper for each reducer + */ +class SelectRandomShuffleTrackerStrategy +extends ShuffleTrackerStrategy with Logging { + private var numMappers = -1 + private var outputLocs: Array[SplitInfo] = null + + private var ranGen = new Random + + // The order of elements in the outputLocs (splitIndex) is used to pass + // information back and forth between the tracker, mappers, and reducers + def initialize(outputLocs_ : Array[SplitInfo]): Unit = { + outputLocs = outputLocs_ + numMappers = outputLocs.size + } + + def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized { + var splitIndex = -1 + + do { + splitIndex = ranGen.nextInt(numMappers) + } while (reducerSplitInfo.hasSplitsBitVector.get(splitIndex)) + + return ArrayBuffer(splitIndex) + } + + def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized { + } + + def deleteReducerFrom(reducerSplitInfo: SplitInfo, + receptionStat: ReceptionStats): Unit = synchronized { + // TODO: This assertion can legally fail when ShuffleClient times out while + // waiting for tracker response and decides to go to a random server + // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) + } +} + +/** + * Shuffle tracker strategy that tries to balance the percentage of blocks + * remaining for each reducer + */ +class BalanceRemainingShuffleTrackerStrategy +extends ShuffleTrackerStrategy with Logging { + // Number of mappers + private var numMappers = -1 + // Number of reducers + private var numReducers = -1 + private var outputLocs: Array[SplitInfo] = null + + // Data structures from reducers' perspectives + private var totalBlocksPerInputSplit: Array[Array[Int]] = null + private var hasBlocksPerInputSplit: Array[Array[Int]] = null + + // Stored in bytes per millisecond + private var speedPerInputSplit: Array[Array[Double]] = null + + private var curConnectionsPerLoc: Array[Int] = null + private var totalConnectionsPerLoc: Array[Int] = null + + // The order of elements in the outputLocs (splitIndex) is used to pass + // information back and forth between the tracker, mappers, and reducers + def initialize(outputLocs_ : Array[SplitInfo]): Unit = { + outputLocs = outputLocs_ + + numMappers = outputLocs.size + + // All the outputLocs have totalBlocksPerOutputSplit of same size + numReducers = outputLocs(0).totalBlocksPerOutputSplit.size + + // Now initialize the data structures + totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) => + outputLocs(j).totalBlocksPerOutputSplit(i)) + hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0) + + // Initialize to -1 + speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1.0) + + curConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0) + totalConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0) + } + + def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized { + var splitIndex = -1 + + // Estimate time remaining to finish receiving for all reducer/mapper pairs + // If speed is unknown or zero then make it 1 to give a large estimate + var individualEstimates = Array.tabulate(numReducers, numMappers)((_,_) => 0.0) + for (i <- 0 until numReducers; j <- 0 until numMappers) { + var blocksRemaining = totalBlocksPerInputSplit(i)(j) - + hasBlocksPerInputSplit(i)(j) + assert(blocksRemaining >= 0) + + individualEstimates(i)(j) = 1.0 * blocksRemaining * Shuffle.BlockSize / + { if (speedPerInputSplit(i)(j) <= 0.0) 1.0 else speedPerInputSplit(i)(j) } + } + + // Check if all speedPerInputSplit entries have non-zero values + var estimationComplete = true + for (i <- 0 until numReducers; j <- 0 until numMappers) { + if (speedPerInputSplit(i)(j) < 0.0) { + estimationComplete = false + } + } + + // Mark mappers where this reducer is too fast + var throttleFromMapper = Array.tabulate(numMappers)(_ => false) + + for (i <- 0 until numMappers) { + var estimatesFromAMapper = + Array.tabulate(numReducers)(j => individualEstimates(j)(i)) + + val estimateOfThisReducer = estimatesFromAMapper(reducerSplitInfo.splitId) + + // Only care if this reducer yet has something to receive from this mapper + if (estimateOfThisReducer > 0) { + // Sort the estimated times + quickSort(estimatesFromAMapper) + + // Find a Shuffle.ThrottleFraction amount of gap + var gapIndex = -1 + for (i <- 0 until numReducers - 1) { + if (gapIndex == -1 && estimatesFromAMapper(i) > 0 && + (Shuffle.ThrottleFraction * estimatesFromAMapper(i) < + estimatesFromAMapper(i + 1))) { + gapIndex = i + } + + assert (estimatesFromAMapper(i) <= estimatesFromAMapper(i + 1)) + } + + // Keep track of how many have completed + var numComplete = estimatesFromAMapper.findIndexOf(i => (i > 0)) + if (numComplete == -1) { + numComplete = numReducers + } + + // TODO: Pick a configurable parameter + if (gapIndex != -1 && (1.0 * (gapIndex - numComplete + 1) < 0.1 * Shuffle.ThrottleFraction * (numReducers - numComplete)) && + estimateOfThisReducer <= estimatesFromAMapper(gapIndex)) { + throttleFromMapper(i) = true + logInfo("Throttling R-%d at M-%d with %f and cut-off %f at %d".format(reducerSplitInfo.splitId, i, estimateOfThisReducer, estimatesFromAMapper(gapIndex + 1), gapIndex)) +// for (i <- 0 until numReducers) { +// print(estimatesFromAMapper(i) + " ") +// } +// println("") + } + } else { + throttleFromMapper(i) = true + } + } + + var minConnections = Int.MaxValue + for (i <- 0 until numMappers) { + // TODO: Use of MaxRxConnections instead of MaxTxConnections is + // intentional here. MaxTxConnections is per machine whereas + // MaxRxConnections is per mapper/reducer. Will have to find a better way. + if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections && + totalConnectionsPerLoc(i) < minConnections && + !reducerSplitInfo.hasSplitsBitVector.get(i) && + !throttleFromMapper(i)) { + minConnections = totalConnectionsPerLoc(i) + splitIndex = i + } + } + + return ArrayBuffer(splitIndex) + } + + def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized { + splitIndices.foreach { splitIndex => + curConnectionsPerLoc(splitIndex) += 1 + totalConnectionsPerLoc(splitIndex) += 1 + } + } + + def deleteReducerFrom(reducerSplitInfo: SplitInfo, + receptionStat: ReceptionStats): Unit = synchronized { + // Update hasBlocksPerInputSplit for reducerSplitInfo + hasBlocksPerInputSplit(reducerSplitInfo.splitId) = + reducerSplitInfo.hasBlocksPerInputSplit + + // Store the last known speed. Add 1 to avoid divide-by-zero. Ignore 0 bytes + // TODO: We are forgetting the old speed. Can use averaging at some point. + if (receptionStat.bytesReceived > 0) { + speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) = + 1.0 * receptionStat.bytesReceived / (receptionStat.timeSpent + 1.0) + } + + // logInfo("%d received %d bytes in %d millis".format(reducerSplitInfo.splitId, receptionStat.bytesReceived, receptionStat.timeSpent)) + + // Update current connections to the mapper + curConnectionsPerLoc(receptionStat.serverSplitIndex) -= 1 + + // TODO: This assertion can legally fail when ShuffleClient times out while + // waiting for tracker response and decides to go to a random server + // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) + + // Just in case + if (curConnectionsPerLoc(receptionStat.serverSplitIndex) < 0) { + curConnectionsPerLoc(receptionStat.serverSplitIndex) = 0 + } + } +} + +/** + * Shuffle tracker strategy that allows reducers to create receiving threads + * depending on their estimated time remaining + */ +class LimitConnectionsShuffleTrackerStrategy +extends ShuffleTrackerStrategy with Logging { + // Number of mappers + private var numMappers = -1 + // Number of reducers + private var numReducers = -1 + private var outputLocs: Array[SplitInfo] = null + + private var ranGen = new Random + + // Data structures from reducers' perspectives + private var totalBlocksPerInputSplit: Array[Array[Int]] = null + private var hasBlocksPerInputSplit: Array[Array[Int]] = null + + // Stored in bytes per millisecond + private var speedPerInputSplit: Array[Array[Double]] = null + + private var curConnectionsPerReducer: Array[Int] = null + private var maxConnectionsPerReducer: Array[Int] = null + + // The order of elements in the outputLocs (splitIndex) is used to pass + // information back and forth between the tracker, mappers, and reducers + def initialize(outputLocs_ : Array[SplitInfo]): Unit = { + outputLocs = outputLocs_ + + numMappers = outputLocs.size + + // All the outputLocs have totalBlocksPerOutputSplit of same size + numReducers = outputLocs(0).totalBlocksPerOutputSplit.size + + // Now initialize the data structures + totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) => + outputLocs(j).totalBlocksPerOutputSplit(i)) + hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0) + + // Initialize to -1 + speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1.0) + + curConnectionsPerReducer = Array.tabulate(numReducers)(_ => 0) + maxConnectionsPerReducer = Array.tabulate(numReducers)(_ => Shuffle.MaxRxConnections) + } + + def selectSplit(reducerSplitInfo: SplitInfo): ArrayBuffer[Int] = synchronized { + var splitIndices = ArrayBuffer[Int]() + + // Estimate time remaining to finish receiving for all reducer/mapper pairs + // If speed is unknown or zero then make it 1 to give a large estimate + var individualEstimates = Array.tabulate(numReducers, numMappers)((_,_) => 0.0) + for (i <- 0 until numReducers; j <- 0 until numMappers) { + var blocksRemaining = totalBlocksPerInputSplit(i)(j) - + hasBlocksPerInputSplit(i)(j) + assert(blocksRemaining >= 0) + + individualEstimates(i)(j) = 1.0 * blocksRemaining * Shuffle.BlockSize / + { if (speedPerInputSplit(i)(j) <= 0.0) 1.0 else speedPerInputSplit(i)(j) } + } + + // Check if all speedPerInputSplit entries have non-zero values + var estimationComplete = true + for (i <- 0 until numReducers; j <- 0 until numMappers) { + if (speedPerInputSplit(i)(j) < 0.0) { + estimationComplete = false + } + } + + // Estimate time remaining to finish receiving for each reducer + var completionEstimates = Array.tabulate(numReducers)( + individualEstimates(_).foldLeft(Double.MinValue)(Math.max(_,_))) + + var numFinished = 0 + for (i <- 0 until numReducers) { + if (completionEstimates(i).toInt == 0) { + numFinished += 1 + } + } + + // If a certain number of reducers have finished already, then don't bother + if (numFinished >= ((1.0 - 0.1 * Shuffle.ThrottleFraction) * numReducers)) { + for (i <- 0 until numReducers) { + maxConnectionsPerReducer(i) = Shuffle.MaxRxConnections + } + // Otherwise, if estimation is complete give reducers proportional threads + } else if (estimationComplete) { + val fastestEstimate = + completionEstimates.foldLeft(Double.MaxValue)(Math.min(_,_)) + val slowestEstimate = + completionEstimates.foldLeft(Double.MinValue)(Math.max(_,_)) + + // Set maxConnectionsPerReducer for all reducers proportional to their + // estimated time remaining with slowestEstimate reducer having the max + for (i <- 0 until numReducers) { + maxConnectionsPerReducer(i) = + ((completionEstimates(i) / slowestEstimate) * Shuffle.MaxRxConnections).toInt + } + } + + // Send back a splitIndex if this reducer is within its limit + if (curConnectionsPerReducer(reducerSplitInfo.splitId) < + maxConnectionsPerReducer(reducerSplitInfo.splitId)) { + + var i = maxConnectionsPerReducer(reducerSplitInfo.splitId) - + curConnectionsPerReducer(reducerSplitInfo.splitId) + + var temp = reducerSplitInfo.hasSplitsBitVector.clone.asInstanceOf[BitSet] + temp.flip(0, numMappers) + + i = Math.min(i, temp.cardinality) + + while (i > 0) { + var splitIndex = -1 + + do { + splitIndex = ranGen.nextInt(numMappers) + } while (reducerSplitInfo.hasSplitsBitVector.get(splitIndex)) + + reducerSplitInfo.hasSplitsBitVector.set(splitIndex) + splitIndices += splitIndex + i -= 1 + } + } + + return splitIndices + } + + def AddReducerToSplit(reducerSplitInfo: SplitInfo, splitIndices: ArrayBuffer[Int]): Unit = synchronized { + splitIndices.foreach { splitIndex => + curConnectionsPerReducer(reducerSplitInfo.splitId) += 1 + } + } + + def deleteReducerFrom(reducerSplitInfo: SplitInfo, + receptionStat: ReceptionStats): Unit = synchronized { + // Update hasBlocksPerInputSplit for reducerSplitInfo + hasBlocksPerInputSplit(reducerSplitInfo.splitId) = + reducerSplitInfo.hasBlocksPerInputSplit + + // Store the last known speed. Add 1 to avoid divide-by-zero. Ignore 0 bytes + // TODO: We are forgetting the old speed. Can use averaging at some point. + if (receptionStat.bytesReceived > 0) { + speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) = + 1.0 * receptionStat.bytesReceived / (receptionStat.timeSpent + 1.0) + } + + // logInfo("%d received %d bytes in %d millis".format(reducerSplitInfo.splitId, receptionStat.bytesReceived, receptionStat.timeSpent)) + + // Update current threads by this reducer + curConnectionsPerReducer(reducerSplitInfo.splitId) -= 1 + + // TODO: This assertion can legally fail when ShuffleClient times out while + // waiting for tracker response and decides to go to a random server + // assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) + + // Just in case + if (curConnectionsPerReducer(reducerSplitInfo.splitId) < 0) { + curConnectionsPerReducer(reducerSplitInfo.splitId) = 0 + } + } +} diff --git a/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala b/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala new file mode 100644 index 0000000000..0d21df9338 --- /dev/null +++ b/core/src/main/scala/spark/TrackedCustomBlockedInMemoryShuffle.scala @@ -0,0 +1,926 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * An implementation of shuffle using memory served through custom server + * where receivers create simultaneous connections to multiple servers by + * setting the 'spark.shuffle.maxRxConnections' config option. + * + * By controlling the 'spark.shuffle.blockSize' config option one can also + * control the largest block size to divide each map output into. Essentially, + * instead of creating one large output file for each reducer, maps create + * multiple smaller files to enable finer level of engagement. + * + * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, + * maxTxConnections >= maxRxConnections * numReducersPerMachine + * + * 'spark.shuffle.TrackerStrategy' decides which strategy to use in the tracker + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class TrackedCustomBlockedInMemoryShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + + @transient var totalBlocksInSplit: Array[Int] = null + @transient var hasBlocksInSplit: Array[Int] = null + + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = TrackedCustomBlockedInMemoryShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + // Keep track of number of blocks for each output split + var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0) + + for (i <- 0 until numOutputSplits) { + var blockNum = 0 + var isDirty = false + + var splitName = "" + var baos: ByteArrayOutputStream = null + var oos: ObjectOutputStream = null + + var writeStartTime: Long = 0 + + buckets(i).foreach(pair => { + // Open a new stream if necessary + if (!isDirty) { + splitName = TrackedCustomBlockedInMemoryShuffle.getSplitName(shuffleId, + myIndex, i, blockNum) + + baos = new ByteArrayOutputStream + oos = new ObjectOutputStream(baos) + + writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + splitName) + } + + oos.writeObject(pair) + isDirty = true + + // Close the old stream if has crossed the blockSize limit + if (baos.size > Shuffle.BlockSize) { + TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = + baos.toByteArray + + logInfo("END WRITE: " + splitName) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + isDirty = false + oos.close() + } + }) + + if (isDirty) { + TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray + + logInfo("END WRITE: " + splitName) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + splitName + " of size " + baos.size + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + oos.close() + } + + // Store BLOCKNUM info + splitName = TrackedCustomBlockedInMemoryShuffle.getBlockNumOutputName( + shuffleId, myIndex, i) + baos = new ByteArrayOutputStream + oos = new ObjectOutputStream(baos) + oos.writeObject(blockNum) + TrackedCustomBlockedInMemoryShuffle.splitsCache(splitName) = baos.toByteArray + + // Close streams + oos.close() + + // Store number of blocks for this outputSplit + numBlocksPerOutputSplit(i) = blockNum + } + + var retVal = SplitInfo(TrackedCustomBlockedInMemoryShuffle.serverAddress, + TrackedCustomBlockedInMemoryShuffle.serverPort, myIndex) + retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit + + (retVal) + }).collect() + + // Start tracker + var shuffleTracker = new ShuffleTracker(outputLocs) + shuffleTracker.setDaemon(true) + shuffleTracker.start() + logInfo("ShuffleTracker started...") + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = outputLocs.size + hasSplits = 0 + + totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) + hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) + + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + // Local status of hasSplitsBitVector and splitsInRequestBitVector + val localSplitInfo = getLocalSplitInfo(myId) + + // DO NOT talk to the tracker if all the required splits are already busy + val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality + + var numThreadsToCreate = + Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Receive which split to pull from the tracker + logInfo("Talking to tracker...") + val startTime = System.currentTimeMillis + val splitIndices = getTrackerSelectedSplit(myId) + logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime)) + + if (splitIndices.size > 0) { + splitIndices.foreach { splitIndex => + val selectedSplitInfo = outputLocs(splitIndex) + val requestSplit = + "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId) + + threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo, + requestSplit, myId)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + } else { + // Tracker replied back with a NO. Sleep for a while. + Thread.sleep(Shuffle.MinKnockInterval) + numThreadsToCreate = 0 + } + + numThreadsToCreate = numThreadsToCreate - splitIndices.size + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + private def getLocalSplitInfo(myId: Int): SplitInfo = { + var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress, + SplitInfo.UnusedParam, myId) + + // Store hasSplits + localSplitInfo.hasSplits = hasSplits + + // Store hasSplitsBitVector + hasSplitsBitVector.synchronized { + localSplitInfo.hasSplitsBitVector = + hasSplitsBitVector.clone.asInstanceOf[BitSet] + } + + // Store hasBlocksInSplit to hasBlocksPerInputSplit + hasBlocksInSplit.synchronized { + localSplitInfo.hasBlocksPerInputSplit = + hasBlocksInSplit.clone.asInstanceOf[Array[Int]] + } + + // Include the splitsInRequest as well + splitsInRequestBitVector.synchronized { + localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector) + } + + return localSplitInfo + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(TrackedCustomBlockedInMemoryShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + // Talks to the tracker and receives instruction + private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = { + // Local status of hasSplitsBitVector and splitsInRequestBitVector + val localSplitInfo = getLocalSplitInfo(myId) + + // DO NOT talk to the tracker if all the required splits are already busy + if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) { + return ArrayBuffer[Int]() + } + + val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, + Shuffle.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + val oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + var selectedSplitIndices = ArrayBuffer[Int]() + + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + logInfo("Waited enough for tracker response... Take random response...") + + // sockets will be closed in finally + // TODO: Sometimes timer wont go off + + // TODO: Selecting randomly here. Tracker won't know about it and get an + // asssertion failure when this thread leaves + + selectedSplitIndices = ArrayBuffer(selectRandomSplit) + } + } + + var timeOutTimer = new Timer + // TODO: Which timeout to use? + // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval) + + try { + // Send intention + oosTracker.writeObject(Shuffle.ReducerEntering) + oosTracker.flush() + + // Send what this reducer has + oosTracker.writeObject(localSplitInfo) + oosTracker.flush() + + // Receive reply from the tracker + selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]] + + // Turn the timer OFF + timeOutTimer.cancel() + } catch { + case e: Exception => { + logInfo("getTrackerSelectedSplit had a " + e) + } + } finally { + oisTracker.close() + oosTracker.close() + clientSocketToTracker.close() + } + + return selectedSplitIndices + } + + class ShuffleTracker(outputLocs: Array[SplitInfo]) + extends Thread with Logging { + var threadPool = Shuffle.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + // Create trackerStrategy object + val trackerStrategyClass = System.getProperty( + "spark.shuffle.trackerStrategy", + "spark.BalanceConnectionsShuffleTrackerStrategy") + + val trackerStrategy = + Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy] + + // Must initialize here by supplying the outputLocs param + // TODO: This could be avoided by directly passing it to the constructor + trackerStrategy.initialize(outputLocs) + + override def run: Unit = { + serverSocket = new ServerSocket(Shuffle.MasterTrackerPort) + logInfo("ShuffleTracker" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + logInfo("ShuffleTracker had a " + e) + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + + try { + // Receive intention + val reducerIntention = ois.readObject.asInstanceOf[Int] + + if (reducerIntention == Shuffle.ReducerEntering) { + // Receive what the reducer has + val reducerSplitInfo = + ois.readObject.asInstanceOf[SplitInfo] + + // Select splits and update stats if necessary + var selectedSplitIndices = ArrayBuffer[Int]() + trackerStrategy.synchronized { + selectedSplitIndices = trackerStrategy.selectSplit( + reducerSplitInfo) + } + + // Send reply back + oos.writeObject(selectedSplitIndices) + oos.flush() + + // Update internal stats, only if receiver got the reply + trackerStrategy.synchronized { + trackerStrategy.AddReducerToSplit(reducerSplitInfo, + selectedSplitIndices) + } + } + else if (reducerIntention == Shuffle.ReducerLeaving) { + val reducerSplitInfo = + ois.readObject.asInstanceOf[SplitInfo] + + // Receive reception stats: how many blocks the reducer + // read in how much time and from where + val receptionStat = + ois.readObject.asInstanceOf[ReceptionStats] + + // Update stats + trackerStrategy.synchronized { + trackerStrategy.deleteReducerFrom(reducerSplitInfo, + receptionStat) + } + + // Send ACK + oos.writeObject(receptionStat.serverSplitIndex) + oos.flush() + } + else { + throw new SparkException("Undefined reducerIntention") + } + } catch { + // EOFException is expected to happen because receiver can + // break connection due to timeout and pick random instead + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleTracker had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + serverSocket.close() + } + // Shutdown the thread pool + threadPool.shutdown() + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo, + requestSplit: String, myId: Int) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + private var receptionSucceeded = false + + // Make sure that multiple messages don't go to the tracker + private var alreadySentLeavingNotification = false + + // Keep track of bytes received and time spent + private var numBytesReceived = 0 + private var totalTimeSpent = 0 + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUp() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) + + try { + // Everything will break if BLOCKNUM is not correctly received + // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown + peerSocketToSource = new Socket( + serversplitInfo.hostAddress, serversplitInfo.listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream(isSource) + + // Send path information + oosSource.writeObject(requestSplit) + + // TODO: Can be optimized. No need to do it everytime. + // Receive BLOCKNUM + totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int] + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) { + // Set receptionSucceeded to false before trying for each block + receptionSucceeded = false + + // Request specific block + oosSource.writeObject(hasBlocksInSplit(splitIndex)) + + // Good to go. First, receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo("Received requestedFileLen = " + requestedFileLen) + + // Create a temp variable to be used in different places + val requestPath = "http://%s:%d/shuffle/%s-%d".format( + serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit, + hasBlocksInSplit(splitIndex)) + + // Receive the file + if (requestedFileLen != -1) { + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: " + requestPath) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 + + // Split has been received only if all the blocks have been received + if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + } + + receptionSucceeded = true + + logInfo("END READ: " + requestPath) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading " + requestPath + " took " + readTime + " millis.") + + // Update stats + numBytesReceived = numBytesReceived + requestedFileLen + totalTimeSpent = totalTimeSpent + readTime.toInt + } else { + throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit) + } + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + cleanUp() + } + } + + // Connect to the tracker and update its stats + private def sendLeavingNotification(): Unit = synchronized { + if (!alreadySentLeavingNotification) { + val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, + Shuffle.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + val oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + try { + // Send intention + oosTracker.writeObject(Shuffle.ReducerLeaving) + oosTracker.flush() + + // Send reducerSplitInfo + oosTracker.writeObject(getLocalSplitInfo(myId)) + oosTracker.flush() + + // Send reception stats + oosTracker.writeObject(ReceptionStats( + numBytesReceived, totalTimeSpent, splitIndex)) + oosTracker.flush() + + // Receive ACK. No need to do anything with that + oisTracker.readObject.asInstanceOf[Int] + + // Now update sentLeavingNotifacation + alreadySentLeavingNotification = true + } catch { + case e: Exception => { + logInfo("sendLeavingNotification had a " + e) + } + } finally { + oisTracker.close() + oosTracker.close() + clientSocketToTracker.close() + } + } + } + + private def cleanUp(): Unit = { + // Update tracker stats first + sendLeavingNotification() + + // Clean up the connections to the mapper + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + + logInfo("Leaving client") + } + } +} + +object TrackedCustomBlockedInMemoryShuffle extends Logging { + // Cache for keeping the splits around + val splitsCache = new HashMap[String, Array[Byte]] + + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def getSplitName(shuffleId: Long, inputId: Int, outputId: Int, + blockId: Int): String = { + initializeIfNeeded() + // Adding shuffleDir is unnecessary. Added to keep the parsers working + return "%s/%d/%d/%d-%d".format(shuffleDir, shuffleId, inputId, outputId, + blockId) + } + + def getBlockNumOutputName(shuffleId: Long, inputId: Int, + outputId: Int): String = { + initializeIfNeeded() + return "%s/%d/%d/%d-BLOCKNUM".format(shuffleDir, shuffleId, inputId, + outputId) + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + class ShuffleServer + extends Thread with Logging { + var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket(0) + serverPort = serverSocket.getLocalPort + + logInfo("ShuffleServer started with " + serverSocket) + logInfo("Local URI: http://" + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ShuffleServerThread(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ShuffleServer now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ShuffleServerThread(val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush() + private val bos = new BufferedOutputStream(os) + bos.flush() + private val oos = new ObjectOutputStream(os) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive basic path information + var requestedSplitBase = ois.readObject.asInstanceOf[String] + + logInfo("requestedSplitBase: " + requestedSplitBase) + + // Read BLOCKNUM and send back the total number of blocks + val blockNumName = "%s/%s-BLOCKNUM".format(shuffleDir, + requestedSplitBase) + + val blockNumIn = new ObjectInputStream(new ByteArrayInputStream( + TrackedCustomBlockedInMemoryShuffle.splitsCache(blockNumName))) + val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int] + blockNumIn.close() + + oos.writeObject(BLOCKNUM) + oos.flush() + + val startTime = System.currentTimeMillis + var curTime = startTime + var keepSending = true + var numBlocksToSend = Shuffle.MaxChatBlocks + + while (keepSending && numBlocksToSend > 0) { + // Receive specific block request + val blockId = ois.readObject.asInstanceOf[Int] + + // Ready to send + var requestedSplit = shuffleDir + "/" + requestedSplitBase + "-" + blockId + + // Send the length of the requestedSplit to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + var requestedSplitLen = -1 + + try { + requestedSplitLen = + TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit).length + } catch { + case e: Exception => { } + } + + oos.writeObject(requestedSplitLen) + oos.flush() + + logInfo("requestedSplitLen = " + requestedSplitLen) + + // Read and send the requested file + if (requestedSplitLen != -1) { + // Send + bos.write(TrackedCustomBlockedInMemoryShuffle.splitsCache(requestedSplit), + 0, requestedSplitLen) + bos.flush() + + // Update loop variables + numBlocksToSend = numBlocksToSend - 1 + + curTime = System.currentTimeMillis + // Revoke sending only if there is anyone waiting in the queue + // TODO: Turning OFF the optimization so that reducers go back to + // tracker get advice + if (curTime - startTime >= Shuffle.MaxChatTime /* && + threadPool.getQueue.size > 0 */) { + keepSending = false + } + } else { + // Close the connection + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc + // then close everything up + // Exception can happen if the receiver stops receiving + // EOFException is expected to happen because receiver can break + // connection as soon as it has all the blocks + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleServerThread had a " + e) + } + } finally { + logInfo("ShuffleServerThread is closing streams and sockets") + ois.close() + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close() + bos.close() + clientSocket.close() + } + } + } + } +} diff --git a/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala b/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala new file mode 100644 index 0000000000..a27fa628c6 --- /dev/null +++ b/core/src/main/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala @@ -0,0 +1,930 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * An implementation of shuffle using local files served through custom server + * where receivers create simultaneous connections to multiple servers by + * setting the 'spark.shuffle.maxRxConnections' config option. + * + * By controlling the 'spark.shuffle.blockSize' config option one can also + * control the largest block size to divide each map output into. Essentially, + * instead of creating one large output file for each reducer, maps create + * multiple smaller files to enable finer level of engagement. + * + * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, + * maxTxConnections >= maxRxConnections * numReducersPerMachine + * + * 'spark.shuffle.TrackerStrategy' decides which strategy to use in the tracker + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class TrackedCustomBlockedLocalFileShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + + @transient var totalBlocksInSplit: Array[Int] = null + @transient var hasBlocksInSplit: Array[Int] = null + + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = TrackedCustomBlockedLocalFileShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files, + // returning a list of inputSplitId -> serverUri pairs + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + // Keep track of number of blocks for each output split + var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0) + + for (i <- 0 until numOutputSplits) { + var blockNum = 0 + var isDirty = false + var file: File = null + var out: ObjectOutputStream = null + + var writeStartTime: Long = 0 + + buckets(i).foreach(pair => { + // Open a new file if necessary + if (!isDirty) { + file = TrackedCustomBlockedLocalFileShuffle.getOutputFile(shuffleId, + myIndex, i, blockNum) + writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) + + out = new ObjectOutputStream(new FileOutputStream(file)) + } + + out.writeObject(pair) + out.flush() + isDirty = true + + // Close the old file if has crossed the blockSize limit + if (file.length > Shuffle.BlockSize) { + out.close() + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + isDirty = false + } + }) + + if (isDirty) { + out.close() + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + + blockNum = blockNum + 1 + } + + // Write the BLOCKNUM file + file = TrackedCustomBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId, + myIndex, i) + out = new ObjectOutputStream(new FileOutputStream(file)) + out.writeObject(blockNum) + out.close() + + // Store number of blocks for this outputSplit + numBlocksPerOutputSplit(i) = blockNum + } + + var retVal = SplitInfo(TrackedCustomBlockedLocalFileShuffle.serverAddress, + TrackedCustomBlockedLocalFileShuffle.serverPort, myIndex) + retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit + + (retVal) + }).collect() + + // Start tracker + var shuffleTracker = new ShuffleTracker(outputLocs) + shuffleTracker.setDaemon(true) + shuffleTracker.start() + logInfo("ShuffleTracker started...") + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = outputLocs.size + hasSplits = 0 + + totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) + hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) + + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + // Local status of hasSplitsBitVector and splitsInRequestBitVector + val localSplitInfo = getLocalSplitInfo(myId) + + // DO NOT talk to the tracker if all the required splits are already busy + val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality + + var numThreadsToCreate = + Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Receive which split to pull from the tracker + logInfo("Talking to tracker...") + val startTime = System.currentTimeMillis + val splitIndices = getTrackerSelectedSplit(myId) + logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime)) + + if (splitIndices.size > 0) { + splitIndices.foreach { splitIndex => + val selectedSplitInfo = outputLocs(splitIndex) + val requestSplit = + "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId) + + threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo, + requestSplit, myId)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + } else { + // Tracker replied back with a NO. Sleep for a while. + Thread.sleep(Shuffle.MinKnockInterval) + numThreadsToCreate = 0 + } + + numThreadsToCreate = numThreadsToCreate - splitIndices.size + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + private def getLocalSplitInfo(myId: Int): SplitInfo = { + var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress, + SplitInfo.UnusedParam, myId) + + // Store hasSplits + localSplitInfo.hasSplits = hasSplits + + // Store hasSplitsBitVector + hasSplitsBitVector.synchronized { + localSplitInfo.hasSplitsBitVector = + hasSplitsBitVector.clone.asInstanceOf[BitSet] + } + + // Store hasBlocksInSplit to hasBlocksPerInputSplit + hasBlocksInSplit.synchronized { + localSplitInfo.hasBlocksPerInputSplit = + hasBlocksInSplit.clone.asInstanceOf[Array[Int]] + } + + // Include the splitsInRequest as well + splitsInRequestBitVector.synchronized { + localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector) + } + + return localSplitInfo + } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(TrackedCustomBlockedLocalFileShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + // Talks to the tracker and receives instruction + private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = { + // Local status of hasSplitsBitVector and splitsInRequestBitVector + val localSplitInfo = getLocalSplitInfo(myId) + + // DO NOT talk to the tracker if all the required splits are already busy + if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) { + return ArrayBuffer[Int]() + } + + val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, + Shuffle.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + val oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + var selectedSplitIndices = ArrayBuffer[Int]() + + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + logInfo("Waited enough for tracker response... Take random response...") + + // sockets will be closed in finally + // TODO: Sometimes timer wont go off + + // TODO: Selecting randomly here. Tracker won't know about it and get an + // asssertion failure when this thread leaves + + selectedSplitIndices = ArrayBuffer(selectRandomSplit) + } + } + + var timeOutTimer = new Timer + // TODO: Which timeout to use? + // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval) + + try { + // Send intention + oosTracker.writeObject(Shuffle.ReducerEntering) + oosTracker.flush() + + // Send what this reducer has + oosTracker.writeObject(localSplitInfo) + oosTracker.flush() + + // Receive reply from the tracker + selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]] + + // Turn the timer OFF + timeOutTimer.cancel() + } catch { + case e: Exception => { + logInfo("getTrackerSelectedSplit had a " + e) + } + } finally { + oisTracker.close() + oosTracker.close() + clientSocketToTracker.close() + } + + return selectedSplitIndices + } + + class ShuffleTracker(outputLocs: Array[SplitInfo]) + extends Thread with Logging { + var threadPool = Shuffle.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + // Create trackerStrategy object + val trackerStrategyClass = System.getProperty( + "spark.shuffle.trackerStrategy", + "spark.BalanceConnectionsShuffleTrackerStrategy") + + val trackerStrategy = + Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy] + + // Must initialize here by supplying the outputLocs param + // TODO: This could be avoided by directly passing it to the constructor + trackerStrategy.initialize(outputLocs) + + override def run: Unit = { + serverSocket = new ServerSocket(Shuffle.MasterTrackerPort) + logInfo("ShuffleTracker" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + logInfo("ShuffleTracker had a " + e) + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + + try { + // Receive intention + val reducerIntention = ois.readObject.asInstanceOf[Int] + + if (reducerIntention == Shuffle.ReducerEntering) { + // Receive what the reducer has + val reducerSplitInfo = + ois.readObject.asInstanceOf[SplitInfo] + + // Select splits and update stats if necessary + var selectedSplitIndices = ArrayBuffer[Int]() + trackerStrategy.synchronized { + selectedSplitIndices = trackerStrategy.selectSplit( + reducerSplitInfo) + } + + // Send reply back + oos.writeObject(selectedSplitIndices) + oos.flush() + + // Update internal stats, only if receiver got the reply + trackerStrategy.synchronized { + trackerStrategy.AddReducerToSplit(reducerSplitInfo, + selectedSplitIndices) + } + } + else if (reducerIntention == Shuffle.ReducerLeaving) { + val reducerSplitInfo = + ois.readObject.asInstanceOf[SplitInfo] + + // Receive reception stats: how many blocks the reducer + // read in how much time and from where + val receptionStat = + ois.readObject.asInstanceOf[ReceptionStats] + + // Update stats + trackerStrategy.synchronized { + trackerStrategy.deleteReducerFrom(reducerSplitInfo, + receptionStat) + } + + // Send ACK + oos.writeObject(receptionStat.serverSplitIndex) + oos.flush() + } + else { + throw new SparkException("Undefined reducerIntention") + } + } catch { + // EOFException is expected to happen because receiver can + // break connection due to timeout and pick random instead + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleTracker had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + serverSocket.close() + } + // Shutdown the thread pool + threadPool.shutdown() + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo, + requestSplit: String, myId: Int) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + private var receptionSucceeded = false + + // Make sure that multiple messages don't go to the tracker + private var alreadySentLeavingNotification = false + + // Keep track of bytes received and time spent + private var numBytesReceived = 0 + private var totalTimeSpent = 0 + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUp() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) + + try { + // Everything will break if BLOCKNUM is not correctly received + // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown + peerSocketToSource = new Socket( + serversplitInfo.hostAddress, serversplitInfo.listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream(isSource) + + // Send path information + oosSource.writeObject(requestSplit) + + // TODO: Can be optimized. No need to do it everytime. + // Receive BLOCKNUM + totalBlocksInSplit(splitIndex) = oisSource.readObject.asInstanceOf[Int] + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + while (hasBlocksInSplit(splitIndex) < totalBlocksInSplit(splitIndex)) { + // Set receptionSucceeded to false before trying for each block + receptionSucceeded = false + + // Request specific block + oosSource.writeObject(hasBlocksInSplit(splitIndex)) + + // Good to go. First, receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo("Received requestedFileLen = " + requestedFileLen) + + // Create a temp variable to be used in different places + val requestPath = "http://%s:%d/shuffle/%s-%d".format( + serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit, + hasBlocksInSplit(splitIndex)) + + // Receive the file + if (requestedFileLen != -1) { + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: " + requestPath) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + while (alreadyRead != requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1 + + // Split has been received only if all the blocks have been received + if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) { + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + } + + receptionSucceeded = true + + logInfo("END READ: " + requestPath) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading " + requestPath + " took " + readTime + " millis.") + + // Update stats + numBytesReceived = numBytesReceived + requestedFileLen + totalTimeSpent = totalTimeSpent + readTime.toInt + } else { + throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit) + } + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + cleanUp() + } + } + + // Connect to the tracker and update its stats + private def sendLeavingNotification(): Unit = synchronized { + if (!alreadySentLeavingNotification) { + val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, + Shuffle.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + val oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + try { + // Send intention + oosTracker.writeObject(Shuffle.ReducerLeaving) + oosTracker.flush() + + // Send reducerSplitInfo + oosTracker.writeObject(getLocalSplitInfo(myId)) + oosTracker.flush() + + // Send reception stats + oosTracker.writeObject(ReceptionStats( + numBytesReceived, totalTimeSpent, splitIndex)) + oosTracker.flush() + + // Receive ACK. No need to do anything with that + oisTracker.readObject.asInstanceOf[Int] + + // Now update sentLeavingNotifacation + alreadySentLeavingNotification = true + } catch { + case e: Exception => { + logInfo("sendLeavingNotification had a " + e) + } + } finally { + oisTracker.close() + oosTracker.close() + clientSocketToTracker.close() + } + } + } + + private def cleanUp(): Unit = { + // Update tracker stats first + sendLeavingNotification() + + // Clean up the connections to the mapper + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + + logInfo("Leaving client") + } + } +} + +object TrackedCustomBlockedLocalFileShuffle extends Logging { + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int, + blockId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "%d-%d".format(outputId, blockId)) + return file + } + + def getBlockNumOutputFile(shuffleId: Long, inputId: Int, + outputId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, outputId + "-BLOCKNUM") + return file + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + class ShuffleServer + extends Thread with Logging { + var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket(0) + serverPort = serverSocket.getLocalPort + + logInfo("ShuffleServer started with " + serverSocket) + logInfo("Local URI: http://" + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ShuffleServerThread(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ShuffleServer now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ShuffleServerThread(val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush() + private val bos = new BufferedOutputStream(os) + bos.flush() + private val oos = new ObjectOutputStream(os) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive basic path information + var requestPathBase = ois.readObject.asInstanceOf[String] + + logInfo("requestPathBase: " + requestPathBase) + + // Read BLOCKNUM file and send back the total number of blocks + val blockNumFilePath = "%s/%s-BLOCKNUM".format(shuffleDir, + requestPathBase) + val blockNumIn = + new ObjectInputStream(new FileInputStream(blockNumFilePath)) + val BLOCKNUM = blockNumIn.readObject.asInstanceOf[Int] + blockNumIn.close() + + oos.writeObject(BLOCKNUM) + oos.flush() + + val startTime = System.currentTimeMillis + var curTime = startTime + var keepSending = true + var numBlocksToSend = Shuffle.MaxChatBlocks + + while (keepSending && numBlocksToSend > 0) { + // Receive specific block request + val blockId = ois.readObject.asInstanceOf[Int] + + // Ready to send + var requestPath = requestPathBase + "-" + blockId + + // Open the file + var requestedFile: File = null + var requestedFileLen = -1 + try { + requestedFile = new File(shuffleDir + "/" + requestPath) + requestedFileLen = requestedFile.length.toInt + } catch { + case e: Exception => { } + } + + // Send the length of the requestPath to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + oos.writeObject(requestedFileLen) + oos.flush() + + logInfo("requestedFileLen = " + requestedFileLen) + + // Read and send the requested file + if (requestedFileLen != -1) { + // Read + var byteArray = new Array[Byte](requestedFileLen) + val bis = + new BufferedInputStream(new FileInputStream(requestedFile)) + + var bytesRead = bis.read(byteArray, 0, byteArray.length) + var alreadyRead = bytesRead + + while (alreadyRead < requestedFileLen) { + bytesRead = bis.read(byteArray, alreadyRead, + (byteArray.length - alreadyRead)) + if(bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + bis.close() + + // Send + bos.write(byteArray, 0, byteArray.length) + bos.flush() + + // Update loop variables + numBlocksToSend = numBlocksToSend - 1 + + curTime = System.currentTimeMillis + // Revoke sending only if there is anyone waiting in the queue + // TODO: Turning OFF the optimization so that reducers go back to + // tracker get advice + if (curTime - startTime >= Shuffle.MaxChatTime /* && + threadPool.getQueue.size > 0 */) { + keepSending = false + } + } else { + // Close the connection + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc + // then close everything up + // Exception can happen if the receiver stops receiving + // EOFException is expected to happen because receiver can break + // connection as soon as it has all the blocks + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleServerThread had a " + e) + } + } finally { + logInfo("ShuffleServerThread is closing streams and sockets") + ois.close() + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close() + bos.close() + clientSocket.close() + } + } + } + } +} diff --git a/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala new file mode 100644 index 0000000000..4a38a8d7ff --- /dev/null +++ b/core/src/main/scala/spark/TrackedCustomParallelLocalFileShuffle.scala @@ -0,0 +1,802 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Random, Timer, TimerTask, UUID} +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +/** + * An implementation of shuffle using local files served through custom server + * where receivers create simultaneous connections to multiple servers by + * setting the 'spark.shuffle.maxRxConnections' config option. + * + * 'spark.shuffle.maxTxConnections' enforces server-side cap. Ideally, + * maxTxConnections >= maxRxConnections * numReducersPerMachine + * + * 'spark.shuffle.TrackerStrategy' decides which strategy to use in the tracker + * + * TODO: Add support for compression when spark.compress is set to true. + */ +@serializable +class TrackedCustomParallelLocalFileShuffle[K, V, C] +extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + @transient var hasSplitsBitVector: BitSet = null + @transient var splitsInRequestBitVector: BitSet = null + + @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null + @transient var combiners: HashMap[K,C] = null + + override def compute(input: RDD[(K, V)], + numOutputSplits: Int, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C) + : RDD[(K, C)] = + { + val sc = input.sparkContext + val shuffleId = TrackedCustomParallelLocalFileShuffle.newShuffleId() + logInfo("Shuffle ID: " + shuffleId) + + val splitRdd = new NumberedSplitRDD(input) + val numInputSplits = splitRdd.splits.size + + // Run a parallel map and collect to write the intermediate data files + val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => { + val myIndex = pair._1 + val myIterator = pair._2 + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C]) + for ((k, v) <- myIterator) { + var bucketId = k.hashCode % numOutputSplits + if (bucketId < 0) { // Fix bucket ID if hash code was negative + bucketId += numOutputSplits + } + val bucket = buckets(bucketId) + bucket(k) = bucket.get(k) match { + case Some(c) => mergeValue(c, v) + case None => createCombiner(v) + } + } + + for (i <- 0 until numOutputSplits) { + val file = TrackedCustomParallelLocalFileShuffle.getOutputFile(shuffleId, + myIndex, i) + val writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) + val out = new ObjectOutputStream(new FileOutputStream(file)) + buckets(i).foreach(pair => out.writeObject(pair)) + out.close() + logInfo("END WRITE: " + file) + val writeTime = System.currentTimeMillis - writeStartTime + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") + } + + (SplitInfo (TrackedCustomParallelLocalFileShuffle.serverAddress, + TrackedCustomParallelLocalFileShuffle.serverPort, myIndex)) + }).collect() + + // Start tracker + var shuffleTracker = new ShuffleTracker(outputLocs) + shuffleTracker.setDaemon(true) + shuffleTracker.start() + logInfo("ShuffleTracker started...") + + // Return an RDD that does each of the merges for a given partition + val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) + return indexes.flatMap((myId: Int) => { + totalSplits = outputLocs.size + hasSplits = 0 + hasSplitsBitVector = new BitSet(totalSplits) + splitsInRequestBitVector = new BitSet(totalSplits) + + receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] + combiners = new HashMap[K, C] + + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) + + while (hasSplits < totalSplits) { + // Local status of hasSplitsBitVector and splitsInRequestBitVector + val localSplitInfo = getLocalSplitInfo(myId) + + // DO NOT talk to the tracker if all the required splits are already busy + val hasOrWillHaveSplits = localSplitInfo.hasSplitsBitVector.cardinality + + var numThreadsToCreate = + Math.min(totalSplits - hasOrWillHaveSplits, Shuffle.MaxRxConnections) - + threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Receive which split to pull from the tracker + logInfo("Talking to tracker...") + val startTime = System.currentTimeMillis + val splitIndices = getTrackerSelectedSplit(myId) + logInfo("Got %s from tracker in %d millis".format(splitIndices, System.currentTimeMillis - startTime)) + + if (splitIndices.size > 0) { + splitIndices.foreach { splitIndex => + val selectedSplitInfo = outputLocs(splitIndex) + val requestSplit = + "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId) + + threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo, + requestSplit, myId)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex) + } + } + } else { + // Tracker replied back with a NO. Sleep for a while. + Thread.sleep(Shuffle.MinKnockInterval) + numThreadsToCreate = 0 + } + + numThreadsToCreate = numThreadsToCreate - splitIndices.size + } + + // Sleep for a while before creating new threads + Thread.sleep(Shuffle.MinKnockInterval) + } + + threadPool.shutdown() + + // Start consumer + // TODO: Consumption is delayed until everything has been received. + // Otherwise it interferes with network performance + var shuffleConsumer = new ShuffleConsumer(mergeCombiners) + shuffleConsumer.setDaemon(true) + shuffleConsumer.start() + logInfo("ShuffleConsumer started...") + + // Don't return until consumption is finished + // while (receivedData.size > 0) { + // Thread.sleep(Shuffle.MinKnockInterval) + // } + + // Wait till shuffleConsumer is done + shuffleConsumer.join + + combiners + }) + } + + private def getLocalSplitInfo(myId: Int): SplitInfo = { + var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress, + SplitInfo.UnusedParam, myId) + + localSplitInfo.hasSplits = hasSplits + + hasSplitsBitVector.synchronized { + localSplitInfo.hasSplitsBitVector = + hasSplitsBitVector.clone.asInstanceOf[BitSet] + } + + // Include the splitsInRequest as well + splitsInRequestBitVector.synchronized { + localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector) + } + + return localSplitInfo + } + + // Selects a random split using local information + private def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(TrackedCustomParallelLocalFileShuffle.ranGen.nextInt( + requiredSplits.size)) + } else { + -1 + } + } + + // Talks to the tracker and receives instruction + private def getTrackerSelectedSplit(myId: Int): ArrayBuffer[Int] = { + // Local status of hasSplitsBitVector and splitsInRequestBitVector + val localSplitInfo = getLocalSplitInfo(myId) + + // DO NOT talk to the tracker if all the required splits are already busy + if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) { + return ArrayBuffer[Int]() + } + + val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, + Shuffle.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + val oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + var selectedSplitIndices = ArrayBuffer[Int]() + + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + logInfo("Waited enough for tracker response... Take random response...") + + // sockets will be closed in finally + // TODO: Sometimes timer wont go off + + // TODO: Selecting randomly here. Tracker won't know about it and get an + // asssertion failure when this thread leaves + + selectedSplitIndices = ArrayBuffer(selectRandomSplit) + } + } + + var timeOutTimer = new Timer + // TODO: Which timeout to use? + // timeOutTimer.schedule(timeOutTask, Shuffle.MinKnockInterval) + + try { + // Send intention + oosTracker.writeObject(Shuffle.ReducerEntering) + oosTracker.flush() + + // Send what this reducer has + oosTracker.writeObject(localSplitInfo) + oosTracker.flush() + + // Receive reply from the tracker + selectedSplitIndices = oisTracker.readObject.asInstanceOf[ArrayBuffer[Int]] + + // Turn the timer OFF + timeOutTimer.cancel() + } catch { + case e: Exception => { + logInfo("getTrackerSelectedSplit had a " + e) + } + } finally { + oisTracker.close() + oosTracker.close() + clientSocketToTracker.close() + } + + return selectedSplitIndices + } + + class ShuffleTracker(outputLocs: Array[SplitInfo]) + extends Thread with Logging { + var threadPool = Shuffle.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + // Create trackerStrategy object + val trackerStrategyClass = System.getProperty( + "spark.shuffle.trackerStrategy", + "spark.BalanceConnectionsShuffleTrackerStrategy") + + val trackerStrategy = + Class.forName(trackerStrategyClass).newInstance().asInstanceOf[ShuffleTrackerStrategy] + + // Must initialize here by supplying the outputLocs param + // TODO: This could be avoided by directly passing it to the constructor + trackerStrategy.initialize(outputLocs) + + override def run: Unit = { + serverSocket = new ServerSocket(Shuffle.MasterTrackerPort) + logInfo("ShuffleTracker" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + logInfo("ShuffleTracker had a " + e) + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + + try { + // Receive intention + val reducerIntention = ois.readObject.asInstanceOf[Int] + + if (reducerIntention == Shuffle.ReducerEntering) { + // Receive what the reducer has + val reducerSplitInfo = + ois.readObject.asInstanceOf[SplitInfo] + + // Select split and update stats if necessary + var selectedSplitIndices = ArrayBuffer[Int]() + trackerStrategy.synchronized { + selectedSplitIndices = trackerStrategy.selectSplit( + reducerSplitInfo) + } + + // Send reply back + oos.writeObject(selectedSplitIndices) + oos.flush() + + // Update internal stats, only if receiver got the reply + trackerStrategy.synchronized { + trackerStrategy.AddReducerToSplit(reducerSplitInfo, + selectedSplitIndices) + } + } + else if (reducerIntention == Shuffle.ReducerLeaving) { + val reducerSplitInfo = + ois.readObject.asInstanceOf[SplitInfo] + + // Receive reception stats: how many blocks the reducer + // read in how much time and from where + val receptionStat = + ois.readObject.asInstanceOf[ReceptionStats] + + // Update stats + trackerStrategy.synchronized { + trackerStrategy.deleteReducerFrom(reducerSplitInfo, + receptionStat) + } + + // Send ACK + oos.writeObject(receptionStat.serverSplitIndex) + oos.flush() + } + else { + throw new SparkException("Undefined reducerIntention") + } + } catch { + // EOFException is expected to happen because receiver can + // break connection due to timeout and pick random instead + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleTracker had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + serverSocket.close() + } + // Shutdown the thread pool + threadPool.shutdown() + } + } + + class ShuffleConsumer(mergeCombiners: (C, C) => C) + extends Thread with Logging { + override def run: Unit = { + // Run until all splits are here + while (receivedData.size > 0) { + var splitIndex = -1 + var recvByteArray: Array[Byte] = null + + try { + var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])] + splitIndex = tempPair._1 + recvByteArray = tempPair._2 + } catch { + case e: Exception => { + logInfo("Exception during taking data from receivedData") + } + } + + val inputStream = + new ObjectInputStream(new ByteArrayInputStream(recvByteArray)) + + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: EOFException => { } + } + inputStream.close() + } + } + } + + class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo, + requestSplit: String, myId: Int) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + private var receptionSucceeded = false + + // Make sure that multiple messages don't go to the tracker + private var alreadySentLeavingNotification = false + + // Keep track of bytes received and time spent + private var numBytesReceived = 0 + private var totalTimeSpent = 0 + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUp() + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval) + + // Create a temp variable to be used in different places + val requestPath = "http://%s:%d/shuffle/%s".format( + serversplitInfo.hostAddress, serversplitInfo.listenPort, requestSplit) + + logInfo("ShuffleClient started... => " + requestPath) + + try { + // Connect to the source + peerSocketToSource = new Socket( + serversplitInfo.hostAddress, serversplitInfo.listenPort) + oosSource = + new ObjectOutputStream(peerSocketToSource.getOutputStream) + oosSource.flush() + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream(isSource) + + // Send the request + oosSource.writeObject(requestSplit) + oosSource.flush() + + // Receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo("Received requestedFileLen = " + requestedFileLen) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel() + + // Receive the file + if (requestedFileLen != -1) { + val readStartTime = System.currentTimeMillis + logInfo("BEGIN READ: " + requestPath) + + // Receive data in an Array[Byte] + var recvByteArray = new Array[Byte](requestedFileLen) + var alreadyRead = 0 + var bytesRead = 0 + + while (alreadyRead < requestedFileLen) { + bytesRead = isSource.read(recvByteArray, alreadyRead, + requestedFileLen - alreadyRead) + if (bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + // Make it available to the consumer + try { + receivedData.put((splitIndex, recvByteArray)) + } catch { + case e: Exception => { + logInfo("Exception during putting data into receivedData") + } + } + + // TODO: Updating stats before consumption is completed + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set(splitIndex) + hasSplits += 1 + } + + // We have received splitIndex + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + + receptionSucceeded = true + + logInfo("END READ: " + requestPath) + val readTime = System.currentTimeMillis - readStartTime + logInfo("Reading " + requestPath + " took " + readTime + " millis.") + + // Update stats + numBytesReceived = numBytesReceived + requestedFileLen + totalTimeSpent = totalTimeSpent + readTime.toInt + } else { + throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit) + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo("ShuffleClient had a " + e) + } + } finally { + // If reception failed, unset for future retry + if (!receptionSucceeded) { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set(splitIndex, false) + } + } + cleanUp() + } + } + + // Connect to the tracker and update its stats + private def sendLeavingNotification(): Unit = synchronized { + if (!alreadySentLeavingNotification) { + val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, + Shuffle.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + val oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + try { + // Send intention + oosTracker.writeObject(Shuffle.ReducerLeaving) + oosTracker.flush() + + // Send reducerSplitInfo + oosTracker.writeObject(getLocalSplitInfo(myId)) + oosTracker.flush() + + // Send reception stats + oosTracker.writeObject(ReceptionStats( + numBytesReceived, totalTimeSpent, splitIndex)) + oosTracker.flush() + + // Receive ACK. No need to do anything with that + oisTracker.readObject.asInstanceOf[Int] + + // Now update sentLeavingNotifacation + alreadySentLeavingNotification = true + } catch { + case e: Exception => { + logInfo("sendLeavingNotification had a " + e) + } + } finally { + oisTracker.close() + oosTracker.close() + clientSocketToTracker.close() + } + } + } + + private def cleanUp(): Unit = { + // Update tracker stats first + sendLeavingNotification() + + // Clean up the connections to the mapper + if (oisSource != null) { + oisSource.close() + } + if (oosSource != null) { + oosSource.close() + } + if (peerSocketToSource != null) { + peerSocketToSource.close() + } + } + } +} + +object TrackedCustomParallelLocalFileShuffle extends Logging { + private var initialized = false + private var nextShuffleId = new AtomicLong(0) + + // Variables initialized by initializeIfNeeded() + private var shuffleDir: File = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 + + // Random number generator + var ranGen = new Random + + private def initializeIfNeeded() = synchronized { + if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { + initializeIfNeeded() + val dir = new File(shuffleDir, shuffleId + "/" + inputId) + dir.mkdirs() + val file = new File(dir, "" + outputId) + return file + } + + def newShuffleId(): Long = { + nextShuffleId.getAndIncrement() + } + + class ShuffleServer + extends Thread with Logging { + var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket(0) + serverPort = serverSocket.getLocalPort + + logInfo("ShuffleServer started with " + serverSocket) + logInfo("Local URI: http://" + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute(new ShuffleServerThread(clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo("ShuffleServer now stopping...") + serverSocket.close() + } + } + // Shutdown the thread pool + threadPool.shutdown() + } + + class ShuffleServerThread(val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush() + private val bos = new BufferedOutputStream(os) + bos.flush() + private val oos = new ObjectOutputStream(os) + oos.flush() + private val ois = new ObjectInputStream(clientSocket.getInputStream) + + logInfo("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive requestPath from the receiver + var requestPath = ois.readObject.asInstanceOf[String] + logInfo("requestPath: " + shuffleDir + "/" + requestPath) + + // Open the file + var requestedFile: File = null + var requestedFileLen = -1 + try { + requestedFile = new File(shuffleDir + "/" + requestPath) + requestedFileLen = requestedFile.length.toInt + } catch { + case e: Exception => { } + } + + // Send the length of the requestPath to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + oos.writeObject(requestedFileLen) + oos.flush() + + logInfo("requestedFileLen = " + requestedFileLen) + + // Read and send the requested file + if (requestedFileLen != -1) { + // Read + var byteArray = new Array[Byte](requestedFileLen) + val bis = + new BufferedInputStream(new FileInputStream(requestedFile)) + + var bytesRead = bis.read(byteArray, 0, byteArray.length) + var alreadyRead = bytesRead + + while (alreadyRead < requestedFileLen) { + bytesRead = bis.read(byteArray, alreadyRead, + (byteArray.length - alreadyRead)) + if(bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + bis.close() + + // Send + bos.write(byteArray, 0, byteArray.length) + bos.flush() + } else { + // Close the connection + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc + // then close everything up + // Exception can happen if the receiver stops receiving + case e: Exception => { + logInfo("ShuffleServerThread had a " + e) + } + } finally { + logInfo("ShuffleServerThread is closing streams and sockets") + ois.close() + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close() + bos.close() + clientSocket.close() + } + } + } + } +} diff --git a/examples/src/main/scala/spark/examples/GroupByTest.scala b/examples/src/main/scala/spark/examples/GroupByTest.scala new file mode 100644 index 0000000000..48c02a52c6 --- /dev/null +++ b/examples/src/main/scala/spark/examples/GroupByTest.scala @@ -0,0 +1,37 @@ +package spark.examples + +import spark.SparkContext +import spark.SparkContext._ +import java.util.Random + +object GroupByTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: GroupByTest <host> [numMappers] [numKVPairs] [KeySize] [numReducers]") + System.exit(1) + } + + var numMappers = if (args.length > 1) args(1).toInt else 2 + var numKVPairs = if (args.length > 2) args(2).toInt else 1000 + var valSize = if (args.length > 3) args(3).toInt else 1000 + var numReducers = if (args.length > 4) args(4).toInt else numMappers + + val sc = new SparkContext(args(0), "GroupBy Test") + + val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val ranGen = new Random + var arr1 = new Array[(Int, Array[Byte])](numKVPairs) + for (i <- 0 until numKVPairs) { + val byteArr = new Array[Byte](valSize) + ranGen.nextBytes(byteArr) + arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) + } + arr1 + }.cache + // Enforce that everything has been calculated and in cache + pairs1.count + + println(pairs1.groupByKey(numReducers).count) + } +} + diff --git a/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala new file mode 100644 index 0000000000..c8edb7d8b4 --- /dev/null +++ b/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala @@ -0,0 +1,51 @@ +package spark.examples + +import spark.SparkContext +import spark.SparkContext._ +import java.util.Random + +object SimpleSkewedGroupByTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SimpleSkewedGroupByTest <host> " + + "[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]") + System.exit(1) + } + + var numMappers = if (args.length > 1) args(1).toInt else 2 + var numKVPairs = if (args.length > 2) args(2).toInt else 1000 + var valSize = if (args.length > 3) args(3).toInt else 1000 + var numReducers = if (args.length > 4) args(4).toInt else numMappers + var ratio = if (args.length > 5) args(5).toInt else 5.0 + + val sc = new SparkContext(args(0), "GroupBy Test") + + val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val ranGen = new Random + var result = new Array[(Int, Array[Byte])](numKVPairs) + for (i <- 0 until numKVPairs) { + val byteArr = new Array[Byte](valSize) + ranGen.nextBytes(byteArr) + val offset = ranGen.nextInt(1000) * numReducers + if (ranGen.nextDouble < ratio / (numReducers + ratio - 1)) { + // give ratio times higher chance of generating key 0 (for reducer 0) + result(i) = (offset, byteArr) + } else { + // generate a key for one of the other reducers + val key = 1 + ranGen.nextInt(numReducers-1) + offset + result(i) = (key, byteArr) + } + } + result + }.cache + // Enforce that everything has been calculated and in cache + pairs1.count + + println("RESULT: " + pairs1.groupByKey(numReducers).count) + // Print how many keys each reducer got (for debugging) + //println("RESULT: " + pairs1.groupByKey(numReducers) + // .map{case (k,v) => (k, v.size)} + // .collectAsMap) + } +} + diff --git a/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala new file mode 100644 index 0000000000..e6dec44bed --- /dev/null +++ b/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala @@ -0,0 +1,41 @@ +package spark.examples + +import spark.SparkContext +import spark.SparkContext._ +import java.util.Random + +object SkewedGroupByTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: GroupByTest <host> [numMappers] [numKVPairs] [KeySize] [numReducers]") + System.exit(1) + } + + var numMappers = if (args.length > 1) args(1).toInt else 2 + var numKVPairs = if (args.length > 2) args(2).toInt else 1000 + var valSize = if (args.length > 3) args(3).toInt else 1000 + var numReducers = if (args.length > 4) args(4).toInt else numMappers + + val sc = new SparkContext(args(0), "GroupBy Test") + + val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val ranGen = new Random + + // map output sizes lineraly increase from the 1st to the last + numKVPairs = (1. * (p + 1) / numMappers * numKVPairs).toInt + + var arr1 = new Array[(Int, Array[Byte])](numKVPairs) + for (i <- 0 until numKVPairs) { + val byteArr = new Array[Byte](valSize) + ranGen.nextBytes(byteArr) + arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) + } + arr1 + }.cache + // Enforce that everything has been calculated and in cache + pairs1.count + + println(pairs1.groupByKey(numReducers).count) + } +} + |