From 5d70ee4663b5589611aef107f914a94f301c7d2a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 May 2013 22:42:15 -0700 Subject: Cleaned up connection manager (moved many classes to their own files). --- .../main/scala/spark/network/BufferMessage.scala | 94 +++++++++++ core/src/main/scala/spark/network/Connection.scala | 137 +++++++++------- .../scala/spark/network/ConnectionManager.scala | 53 +++--- .../scala/spark/network/ConnectionManagerId.scala | 21 +++ core/src/main/scala/spark/network/Message.scala | 179 ++------------------- .../main/scala/spark/network/MessageChunk.scala | 25 +++ .../scala/spark/network/MessageChunkHeader.scala | 58 +++++++ 7 files changed, 315 insertions(+), 252 deletions(-) create mode 100644 core/src/main/scala/spark/network/BufferMessage.scala create mode 100644 core/src/main/scala/spark/network/ConnectionManagerId.scala create mode 100644 core/src/main/scala/spark/network/MessageChunk.scala create mode 100644 core/src/main/scala/spark/network/MessageChunkHeader.scala (limited to 'core') diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala new file mode 100644 index 0000000000..7b0e489a6c --- /dev/null +++ b/core/src/main/scala/spark/network/BufferMessage.scala @@ -0,0 +1,94 @@ +package spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import spark.storage.BlockManager + + +private[spark] +class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) + extends Message(Message.BUFFER_MESSAGE, id_) { + + val initialSize = currentSize() + var gotChunkForSendingOnce = false + + def size = initialSize + + def currentSize() = { + if (buffers == null || buffers.isEmpty) { + 0 + } else { + buffers.map(_.remaining).reduceLeft(_ + _) + } + } + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { + if (maxChunkSize <= 0) { + throw new Exception("Max chunk size is " + maxChunkSize) + } + + if (size == 0 && gotChunkForSendingOnce == false) { + val newChunk = new MessageChunk( + new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + gotChunkForSendingOnce = true + return Some(newChunk) + } + + while(!buffers.isEmpty) { + val buffer = buffers(0) + if (buffer.remaining == 0) { + BlockManager.dispose(buffer) + buffers -= buffer + } else { + val newBuffer = if (buffer.remaining <= maxChunkSize) { + buffer.duplicate() + } else { + buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] + } + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + gotChunkForSendingOnce = true + return Some(newChunk) + } + } + None + } + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { + // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer + if (buffers.size > 1) { + throw new Exception("Attempting to get chunk from message with multiple data buffers") + } + val buffer = buffers(0) + if (buffer.remaining > 0) { + if (buffer.remaining < chunkSize) { + throw new Exception("Not enough space in data buffer for receiving chunk") + } + val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + return Some(newChunk) + } + None + } + + def flip() { + buffers.foreach(_.flip) + } + + def hasAckId() = (ackId != 0) + + def isCompletelyReceived() = !buffers(0).hasRemaining + + override def toString = { + if (hasAckId) { + "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" + } else { + "BufferMessage(id = " + id + ", size = " + size + ")" + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 00a0433a44..6e28f677a3 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -13,12 +13,13 @@ import java.net._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging { + val socketRemoteConnectionManagerId: ConnectionManagerId) + extends Logging { + def this(channel_ : SocketChannel, selector_ : Selector) = { this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - )) + ConnectionManagerId.fromSocketAddress( + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) } channel.configureBlocking(false) @@ -32,17 +33,19 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() - + // Read channels typically do not register for write and write does not for read // Now, we do have write registering for read too (temporarily), but this is to detect // channel close NOT to actually read/consume data on it ! // How does this work if/when we move to SSL ? - + // What is the interest to register with selector for when we want this connection to be selected def registerInterest() - // What is the interest to register with selector for when we want this connection to be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be - // SelectionKey.OP_READ (until we fix it properly) + + // What is the interest to register with selector for when we want this connection to + // be de-selected + // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, + // it will be SelectionKey.OP_READ (until we fix it properly) def unregisterInterest() // On receiving a read event, should we change the interest for this channel or not ? @@ -64,12 +67,14 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, // Returns whether we have to register for further reads or not. def read(): Boolean = { - throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + throw new UnsupportedOperationException( + "Cannot read on connection of type " + this.getClass.toString) } // Returns whether we have to register for further writes or not. def write(): Boolean = { - throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + throw new UnsupportedOperationException( + "Cannot write on connection of type " + this.getClass.toString) } def close() { @@ -81,11 +86,17 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, callOnCloseCallback() } - def onClose(callback: Connection => Unit) {onCloseCallback = callback} + def onClose(callback: Connection => Unit) { + onCloseCallback = callback + } - def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback} + def onException(callback: (Connection, Exception) => Unit) { + onExceptionCallback = callback + } - def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback} + def onKeyInterestChange(callback: (Connection, Int) => Unit) { + onKeyInterestChangeCallback = callback + } def callOnExceptionCallback(e: Exception) { if (onExceptionCallback != null) { @@ -95,7 +106,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, " and OnExceptionCallback not registered", e) } } - + def callOnCloseCallback() { if (onCloseCallback != null) { onCloseCallback(this) @@ -132,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, print(" (" + position + ", " + length + ")") buffer.position(curPosition) } - } -private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) -extends Connection(SocketChannel.open, selector_, remoteId_) { +private[spark] +class SendingConnection(val address: InetSocketAddress, selector_ : Selector, + remoteId_ : ConnectionManagerId) + extends Connection(SocketChannel.open, selector_, remoteId_) { class Outbox(fair: Int = 0) { val messages = new Queue[Message]() - val defaultChunkSize = 65536 //32768 //16384 + val defaultChunkSize = 65536 //32768 //16384 var nextMessageToBeUsed = 0 def addMessage(message: Message) { - messages.synchronized{ + messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]") + logDebug("Added [" + message + "] to outbox for sending to " + + "[" + getRemoteConnectionManagerId() + "]") } } @@ -174,7 +186,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { message.started = true message.startTime = System.currentTimeMillis } - return chunk + return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ message.finishTime = System.currentTimeMillis @@ -185,7 +197,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } None } - + private def getChunkRR(): Option[MessageChunk] = { messages.synchronized { while (!messages.isEmpty) { @@ -197,12 +209,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.enqueue(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { - logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") + logDebug( + "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") message.started = true message.startTime = System.currentTimeMillis } - logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") - return chunk + logTrace( + "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") + return chunk } else { message.finishTime = System.currentTimeMillis logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + @@ -213,7 +227,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { None } } - + private val outbox = new Outbox(1) val currentBuffers = new ArrayBuffer[ByteBuffer]() @@ -228,11 +242,11 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // it does - so let us keep it for now. changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) } - + override def unregisterInterest() { changeConnectionKeyInterest(DEFAULT_INTEREST) } - + def send(message: Message) { outbox.synchronized { outbox.addMessage(message) @@ -262,12 +276,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // selection - though need not necessarily always complete successfully. val connected = channel.finishConnect if (!force && !connected) { - logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") + logInfo( + "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") return false } // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times (10 or so) + // This will happen only when finishConnect failed for some repeated number of times + // (10 or so) // Is highly unlikely unless there was an unclean close of socket, etc registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") @@ -283,13 +299,13 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } override def write(): Boolean = { - try{ - while(true) { + try { + while (true) { if (currentBuffers.size == 0) { outbox.synchronized { outbox.getChunk() match { case Some(chunk) => { - currentBuffers ++= chunk.buffers + currentBuffers ++= chunk.buffers } case None => { // changeConnectionKeyInterest(0) @@ -299,7 +315,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } } - + if (currentBuffers.size > 0) { val buffer = currentBuffers(0) val remainingBytes = buffer.remaining @@ -314,7 +330,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() @@ -336,7 +352,8 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { if (length == -1) { // EOF close() } else if (length > 0) { - logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) + logWarning( + "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) } } catch { case e: Exception => @@ -355,30 +372,32 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) -extends Connection(channel_, selector_) { - +private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) + extends Connection(channel_, selector_) { + class Inbox() { val messages = new HashMap[Int, BufferMessage]() - + def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - + def createNewMessage: BufferMessage = { val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis - logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") + logDebug( + "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) newMessage } - + val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") + logTrace( + "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") message.getChunkForReceiving(header.chunkSize) } - + def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) + messages.get(chunk.header.id) } def removeMessage(message: Message) { @@ -387,12 +406,14 @@ extends Connection(channel_, selector_) { } @volatile private var inferredRemoteManagerId: ConnectionManagerId = null + override def getRemoteConnectionManagerId(): ConnectionManagerId = { val currId = inferredRemoteManagerId if (currId != null) currId else super.getRemoteConnectionManagerId() } - // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver. + // The reciever's remote address is the local socket on remote side : which is NOT + // the connection manager id of the receiver. // We infer that from the messages we receive on the receiver socket. private def processConnectionManagerId(header: MessageChunkHeader) { val currId = inferredRemoteManagerId @@ -428,7 +449,8 @@ extends Connection(channel_, selector_) { } headerBuffer.flip if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") + throw new Exception( + "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") } val header = MessageChunkHeader.create(headerBuffer) headerBuffer.clear() @@ -451,9 +473,9 @@ extends Connection(channel_, selector_) { case _ => throw new Exception("Message of unknown type received") } } - + if (currentChunk == null) throw new Exception("No message chunk to receive data") - + val bytesRead = channel.read(currentChunk.buffer) if (bytesRead == 0) { // re-register for read event ... @@ -464,14 +486,15 @@ extends Connection(channel_, selector_) { } /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ - + if (currentChunk.buffer.remaining == 0) { /*println("Filled buffer at " + System.currentTimeMillis)*/ val bufferMessage = inbox.getMessageForChunk(currentChunk).get if (bufferMessage.isCompletelyReceived) { bufferMessage.flip bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) + logDebug("Finished receiving [" + bufferMessage + "] from " + + "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) if (onReceiveCallback != null) { onReceiveCallback(this, bufferMessage) } @@ -481,7 +504,7 @@ extends Connection(channel_, selector_) { } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() @@ -491,7 +514,7 @@ extends Connection(channel_, selector_) { // should not happen - to keep scala compiler happy return true } - + def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} override def changeInterestForRead(): Boolean = true @@ -505,7 +528,7 @@ extends Connection(channel_, selector_) { // it does - so let us keep it for now. changeConnectionKeyInterest(SelectionKey.OP_READ) } - + override def unregisterInterest() { changeConnectionKeyInterest(0) } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 0eb03630d0..624a094856 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -18,20 +18,7 @@ import akka.dispatch.{Await, Promise, ExecutionContext, Future} import akka.util.Duration import akka.util.duration._ -private[spark] case class ConnectionManagerId(host: String, port: Int) { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - def toSocketAddress() = new InetSocketAddress(host, port) -} - -private[spark] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) - } -} - private[spark] class ConnectionManager(port: Int) extends Logging { class MessageStatus( @@ -45,7 +32,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def markDone() { completionHandler(this) } } - + private val selector = SelectorProvider.provider.openSelector() private val handleMessageExecutor = new ThreadPoolExecutor( @@ -80,7 +67,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) + serverChannel.socket.setReceiveBufferSize(256 * 1024) serverChannel.socket.bind(new InetSocketAddress(port)) serverChannel.register(selector, SelectionKey.OP_ACCEPT) @@ -351,7 +338,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { case e: Exception => logError("Error in select loop", e) } } - + def acceptConnection(key: SelectionKey) { val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] @@ -463,7 +450,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") + logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { @@ -483,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (bufferMessage.hasAckId) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId + case Some(status) => { + messageStatuses -= bufferMessage.ackId status } - case None => { + case None => { throw new Exception("Could not find reference for received ack message " + message.id) null } @@ -507,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logDebug("Not calling back as callback is null") None } - + if (ackMessage.isDefined) { if (!ackMessage.get.isInstanceOf[BufferMessage]) { logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) @@ -517,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - sendMessage(connectionManagerId, ackMessage.getOrElse { + sendMessage(connectionManagerId, ackMessage.getOrElse { Message.createBufferMessage(bufferMessage.id) }) } @@ -588,17 +575,17 @@ private[spark] object ConnectionManager { def main(args: Array[String]) { val manager = new ConnectionManager(9999) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None }) - + /*testSequentialSending(manager)*/ /*System.gc()*/ /*testParallelSending(manager)*/ /*System.gc()*/ - + /*testParallelDecreasingSending(manager)*/ /*System.gc()*/ @@ -610,9 +597,9 @@ private[spark] object ConnectionManager { println("--------------------------") println("Sequential Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 - + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -628,7 +615,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Parallel Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) @@ -643,12 +630,12 @@ private[spark] object ConnectionManager { if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis - + val mb = size * count / 1024.0 / 1024.0 val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) + println("Started at " + startTime + ", finished at " + finishTime) println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() @@ -658,7 +645,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Parallel Decreasing Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) buffers.foreach(_.flip) @@ -673,7 +660,7 @@ private[spark] object ConnectionManager { if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis - + val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") @@ -687,7 +674,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Continuous Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala new file mode 100644 index 0000000000..b554e84251 --- /dev/null +++ b/core/src/main/scala/spark/network/ConnectionManagerId.scala @@ -0,0 +1,21 @@ +package spark.network + +import java.net.InetSocketAddress + +import spark.Utils + + +private[spark] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + def toSocketAddress() = new InetSocketAddress(host, port) +} + + +private[spark] object ConnectionManagerId { + def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { + new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) + } +} diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 34fac9e776..d4f03610eb 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -1,56 +1,10 @@ package spark.network -import spark._ - -import scala.collection.mutable.ArrayBuffer - import java.nio.ByteBuffer -import java.net.InetAddress import java.net.InetSocketAddress -import storage.BlockManager - -private[spark] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). - flip.asInstanceOf[ByteBuffer] - } - - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" -} -private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } +import scala.collection.mutable.ArrayBuffer - override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" -} private[spark] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -59,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var finishTime = -1L def size: Int - + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - + def timeTaken(): String = (finishTime - startTime).toString + " ms" override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" } -private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) -extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size = initialSize - - def currentSize() = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - if (size == 0 && gotChunkForSendingOnce == false) { - val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId() = (ackId != 0) - - def isCompletelyReceived() = !buffers(0).hasRemaining - - override def toString = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} - -private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) - } -} private[spark] object Message { val BUFFER_MESSAGE = 1111111111L @@ -181,14 +31,16 @@ private[spark] object Message { def getNewId() = synchronized { lastId += 1 - if (lastId == 0) lastId += 1 + if (lastId == 0) { + lastId += 1 + } lastId } def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { if (dataBuffers == null) { return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } + } if (dataBuffers.exists(_ == null)) { throw new Exception("Attempting to create buffer message with null buffer") } @@ -197,7 +49,7 @@ private[spark] object Message { def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = createBufferMessage(dataBuffers, 0) - + def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { if (dataBuffer == null) { return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) @@ -205,15 +57,18 @@ private[spark] object Message { return createBufferMessage(Array(dataBuffer), ackId) } } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = + + def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId) + + def createBufferMessage(ackId: Int): BufferMessage = { + createBufferMessage(new Array[ByteBuffer](0), ackId) + } def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) + case BUFFER_MESSAGE => new BufferMessage(header.id, + ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } newMessage.senderAddress = header.address newMessage diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala new file mode 100644 index 0000000000..aaf9204d0e --- /dev/null +++ b/core/src/main/scala/spark/network/MessageChunk.scala @@ -0,0 +1,25 @@ +package spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + + +private[network] +class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { + + val size = if (buffer == null) 0 else buffer.remaining + + lazy val buffers = { + val ab = new ArrayBuffer[ByteBuffer]() + ab += header.buffer + if (buffer != null) { + ab += buffer + } + ab + } + + override def toString = { + "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" + } +} diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala new file mode 100644 index 0000000000..3693d509d6 --- /dev/null +++ b/core/src/main/scala/spark/network/MessageChunkHeader.scala @@ -0,0 +1,58 @@ +package spark.network + +import java.net.InetAddress +import java.net.InetSocketAddress +import java.nio.ByteBuffer + + +private[spark] class MessageChunkHeader( + val typ: Long, + val id: Int, + val totalSize: Int, + val chunkSize: Int, + val other: Int, + val address: InetSocketAddress) { + lazy val buffer = { + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. + // Refer to network.Connection + val ip = address.getAddress.getAddress() + val port = address.getPort() + ByteBuffer. + allocate(MessageChunkHeader.HEADER_SIZE). + putLong(typ). + putInt(id). + putInt(totalSize). + putInt(chunkSize). + putInt(other). + putInt(ip.size). + put(ip). + putInt(port). + position(MessageChunkHeader.HEADER_SIZE). + flip.asInstanceOf[ByteBuffer] + } + + override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + " and sizes " + totalSize + " / " + chunkSize + " bytes" +} + + +private[spark] object MessageChunkHeader { + val HEADER_SIZE = 40 + + def create(buffer: ByteBuffer): MessageChunkHeader = { + if (buffer.remaining != HEADER_SIZE) { + throw new IllegalArgumentException("Cannot convert buffer data to Message") + } + val typ = buffer.getLong() + val id = buffer.getInt() + val totalSize = buffer.getInt() + val chunkSize = buffer.getInt() + val other = buffer.getInt() + val ipSize = buffer.getInt() + val ipBytes = new Array[Byte](ipSize) + buffer.get(ipBytes) + val ip = InetAddress.getByAddress(ipBytes) + val port = buffer.getInt() + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + } +} -- cgit v1.2.3