aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorReynold Xin <rxin@cs.berkeley.edu>2013-05-07 22:42:15 -0700
committerReynold Xin <rxin@cs.berkeley.edu>2013-05-07 22:42:15 -0700
commit5d70ee4663b5589611aef107f914a94f301c7d2a (patch)
treeeae0d0fb3ab45f20f4c24ce80c5ad7de0dd55205 /core
parent8388e8dd7ab5e55ea67b329d9359ba2147d796b0 (diff)
downloadspark-5d70ee4663b5589611aef107f914a94f301c7d2a.tar.gz
spark-5d70ee4663b5589611aef107f914a94f301c7d2a.tar.bz2
spark-5d70ee4663b5589611aef107f914a94f301c7d2a.zip
Cleaned up connection manager (moved many classes to their own files).
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/spark/network/BufferMessage.scala94
-rw-r--r--core/src/main/scala/spark/network/Connection.scala137
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala53
-rw-r--r--core/src/main/scala/spark/network/ConnectionManagerId.scala21
-rw-r--r--core/src/main/scala/spark/network/Message.scala179
-rw-r--r--core/src/main/scala/spark/network/MessageChunk.scala25
-rw-r--r--core/src/main/scala/spark/network/MessageChunkHeader.scala58
7 files changed, 315 insertions, 252 deletions
diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala
new file mode 100644
index 0000000000..7b0e489a6c
--- /dev/null
+++ b/core/src/main/scala/spark/network/BufferMessage.scala
@@ -0,0 +1,94 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark.storage.BlockManager
+
+
+private[spark]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+ extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ if (size == 0 && gotChunkForSendingOnce == false) {
+ val newChunk = new MessageChunk(
+ new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while(!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ BlockManager.dispose(buffer)
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate()
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+ }
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index 00a0433a44..6e28f677a3 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -13,12 +13,13 @@ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging {
+ val socketRemoteConnectionManagerId: ConnectionManagerId)
+ extends Logging {
+
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
- ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
- ))
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
}
channel.configureBlocking(false)
@@ -32,17 +33,19 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress()
-
+
// Read channels typically do not register for write and write does not for read
// Now, we do have write registering for read too (temporarily), but this is to detect
// channel close NOT to actually read/consume data on it !
// How does this work if/when we move to SSL ?
-
+
// What is the interest to register with selector for when we want this connection to be selected
def registerInterest()
- // What is the interest to register with selector for when we want this connection to be de-selected
- // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be
- // SelectionKey.OP_READ (until we fix it properly)
+
+ // What is the interest to register with selector for when we want this connection to
+ // be de-selected
+ // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack,
+ // it will be SelectionKey.OP_READ (until we fix it properly)
def unregisterInterest()
// On receiving a read event, should we change the interest for this channel or not ?
@@ -64,12 +67,14 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
// Returns whether we have to register for further reads or not.
def read(): Boolean = {
- throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
+ throw new UnsupportedOperationException(
+ "Cannot read on connection of type " + this.getClass.toString)
}
// Returns whether we have to register for further writes or not.
def write(): Boolean = {
- throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
+ throw new UnsupportedOperationException(
+ "Cannot write on connection of type " + this.getClass.toString)
}
def close() {
@@ -81,11 +86,17 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
callOnCloseCallback()
}
- def onClose(callback: Connection => Unit) {onCloseCallback = callback}
+ def onClose(callback: Connection => Unit) {
+ onCloseCallback = callback
+ }
- def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
+ def onException(callback: (Connection, Exception) => Unit) {
+ onExceptionCallback = callback
+ }
- def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+ onKeyInterestChangeCallback = callback
+ }
def callOnExceptionCallback(e: Exception) {
if (onExceptionCallback != null) {
@@ -95,7 +106,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
" and OnExceptionCallback not registered", e)
}
}
-
+
def callOnCloseCallback() {
if (onCloseCallback != null) {
onCloseCallback(this)
@@ -132,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
print(" (" + position + ", " + length + ")")
buffer.position(curPosition)
}
-
}
-private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
-extends Connection(SocketChannel.open, selector_, remoteId_) {
+private[spark]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId)
+ extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
- val defaultChunkSize = 65536 //32768 //16384
+ val defaultChunkSize = 65536 //32768 //16384
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
- messages.synchronized{
+ messages.synchronized{
/*messages += message*/
messages.enqueue(message)
- logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]")
+ logDebug("Added [" + message + "] to outbox for sending to " +
+ "[" + getRemoteConnectionManagerId() + "]")
}
}
@@ -174,7 +186,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
message.started = true
message.startTime = System.currentTimeMillis
}
- return chunk
+ return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
message.finishTime = System.currentTimeMillis
@@ -185,7 +197,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
None
}
-
+
private def getChunkRR(): Option[MessageChunk] = {
messages.synchronized {
while (!messages.isEmpty) {
@@ -197,12 +209,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.enqueue(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
- logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+ logDebug(
+ "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
- logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
- return chunk
+ logTrace(
+ "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
+ return chunk
} else {
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
@@ -213,7 +227,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
None
}
}
-
+
private val outbox = new Outbox(1)
val currentBuffers = new ArrayBuffer[ByteBuffer]()
@@ -228,11 +242,11 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
// it does - so let us keep it for now.
changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
}
-
+
override def unregisterInterest() {
changeConnectionKeyInterest(DEFAULT_INTEREST)
}
-
+
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
@@ -262,12 +276,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
// selection - though need not necessarily always complete successfully.
val connected = channel.finishConnect
if (!force && !connected) {
- logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+ logInfo(
+ "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
return false
}
// Fallback to previous behavior - assume finishConnect completed
- // This will happen only when finishConnect failed for some repeated number of times (10 or so)
+ // This will happen only when finishConnect failed for some repeated number of times
+ // (10 or so)
// Is highly unlikely unless there was an unclean close of socket, etc
registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
@@ -283,13 +299,13 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
override def write(): Boolean = {
- try{
- while(true) {
+ try {
+ while (true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
outbox.getChunk() match {
case Some(chunk) => {
- currentBuffers ++= chunk.buffers
+ currentBuffers ++= chunk.buffers
}
case None => {
// changeConnectionKeyInterest(0)
@@ -299,7 +315,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
}
-
+
if (currentBuffers.size > 0) {
val buffer = currentBuffers(0)
val remainingBytes = buffer.remaining
@@ -314,7 +330,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
@@ -336,7 +352,8 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
if (length == -1) { // EOF
close()
} else if (length > 0) {
- logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
+ logWarning(
+ "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
}
} catch {
case e: Exception =>
@@ -355,30 +372,32 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
// Must be created within selector loop - else deadlock
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
-extends Connection(channel_, selector_) {
-
+private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+ extends Connection(channel_, selector_) {
+
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
-
+
def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
-
+
def createNewMessage: BufferMessage = {
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
- logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
+ logDebug(
+ "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
newMessage
}
-
+
val message = messages.getOrElseUpdate(header.id, createNewMessage)
- logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
+ logTrace(
+ "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
message.getChunkForReceiving(header.chunkSize)
}
-
+
def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
- messages.get(chunk.header.id)
+ messages.get(chunk.header.id)
}
def removeMessage(message: Message) {
@@ -387,12 +406,14 @@ extends Connection(channel_, selector_) {
}
@volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
override def getRemoteConnectionManagerId(): ConnectionManagerId = {
val currId = inferredRemoteManagerId
if (currId != null) currId else super.getRemoteConnectionManagerId()
}
- // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver.
+ // The reciever's remote address is the local socket on remote side : which is NOT
+ // the connection manager id of the receiver.
// We infer that from the messages we receive on the receiver socket.
private def processConnectionManagerId(header: MessageChunkHeader) {
val currId = inferredRemoteManagerId
@@ -428,7 +449,8 @@ extends Connection(channel_, selector_) {
}
headerBuffer.flip
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
- throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ throw new Exception(
+ "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
}
val header = MessageChunkHeader.create(headerBuffer)
headerBuffer.clear()
@@ -451,9 +473,9 @@ extends Connection(channel_, selector_) {
case _ => throw new Exception("Message of unknown type received")
}
}
-
+
if (currentChunk == null) throw new Exception("No message chunk to receive data")
-
+
val bytesRead = channel.read(currentChunk.buffer)
if (bytesRead == 0) {
// re-register for read event ...
@@ -464,14 +486,15 @@ extends Connection(channel_, selector_) {
}
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
-
+
if (currentChunk.buffer.remaining == 0) {
/*println("Filled buffer at " + System.currentTimeMillis)*/
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.finishTime = System.currentTimeMillis
- logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
+ logDebug("Finished receiving [" + bufferMessage + "] from " +
+ "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
if (onReceiveCallback != null) {
onReceiveCallback(this, bufferMessage)
}
@@ -481,7 +504,7 @@ extends Connection(channel_, selector_) {
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
@@ -491,7 +514,7 @@ extends Connection(channel_, selector_) {
// should not happen - to keep scala compiler happy
return true
}
-
+
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
override def changeInterestForRead(): Boolean = true
@@ -505,7 +528,7 @@ extends Connection(channel_, selector_) {
// it does - so let us keep it for now.
changeConnectionKeyInterest(SelectionKey.OP_READ)
}
-
+
override def unregisterInterest() {
changeConnectionKeyInterest(0)
}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 0eb03630d0..624a094856 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -18,20 +18,7 @@ import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration
import akka.util.duration._
-private[spark] case class ConnectionManagerId(host: String, port: Int) {
- // DEBUG code
- Utils.checkHost(host)
- assert (port > 0)
- def toSocketAddress() = new InetSocketAddress(host, port)
-}
-
-private[spark] object ConnectionManagerId {
- def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
- new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
- }
-}
-
private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus(
@@ -45,7 +32,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def markDone() { completionHandler(this) }
}
-
+
private val selector = SelectorProvider.provider.openSelector()
private val handleMessageExecutor = new ThreadPoolExecutor(
@@ -80,7 +67,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
- serverChannel.socket.setReceiveBufferSize(256 * 1024)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
serverChannel.socket.bind(new InetSocketAddress(port))
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
@@ -351,7 +338,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
case e: Exception => logError("Error in select loop", e)
}
}
-
+
def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
@@ -463,7 +450,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
- logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
@@ -483,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
- case Some(status) => {
- messageStatuses -= bufferMessage.ackId
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
status
}
- case None => {
+ case None => {
throw new Exception("Could not find reference for received ack message " + message.id)
null
}
@@ -507,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logDebug("Not calling back as callback is null")
None
}
-
+
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
@@ -517,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- sendMessage(connectionManagerId, ackMessage.getOrElse {
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
@@ -588,17 +575,17 @@ private[spark] object ConnectionManager {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
-
+
/*testSequentialSending(manager)*/
/*System.gc()*/
/*testParallelSending(manager)*/
/*System.gc()*/
-
+
/*testParallelDecreasingSending(manager)*/
/*System.gc()*/
@@ -610,9 +597,9 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Sequential Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
-
+
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@@ -628,7 +615,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
@@ -643,12 +630,12 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
-
+
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
- println("Started at " + startTime + ", finished at " + finishTime)
+ println("Started at " + startTime + ", finished at " + finishTime)
println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
@@ -658,7 +645,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Decreasing Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
buffers.foreach(_.flip)
@@ -673,7 +660,7 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
-
+
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
@@ -687,7 +674,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Continuous Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala
new file mode 100644
index 0000000000..b554e84251
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManagerId.scala
@@ -0,0 +1,21 @@
+package spark.network
+
+import java.net.InetSocketAddress
+
+import spark.Utils
+
+
+private[spark] case class ConnectionManagerId(host: String, port: Int) {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+
+private[spark] object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+ }
+}
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
index 34fac9e776..d4f03610eb 100644
--- a/core/src/main/scala/spark/network/Message.scala
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -1,56 +1,10 @@
package spark.network
-import spark._
-
-import scala.collection.mutable.ArrayBuffer
-
import java.nio.ByteBuffer
-import java.net.InetAddress
import java.net.InetSocketAddress
-import storage.BlockManager
-
-private[spark] class MessageChunkHeader(
- val typ: Long,
- val id: Int,
- val totalSize: Int,
- val chunkSize: Int,
- val other: Int,
- val address: InetSocketAddress) {
- lazy val buffer = {
- // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection
- val ip = address.getAddress.getAddress()
- val port = address.getPort()
- ByteBuffer.
- allocate(MessageChunkHeader.HEADER_SIZE).
- putLong(typ).
- putInt(id).
- putInt(totalSize).
- putInt(chunkSize).
- putInt(other).
- putInt(ip.size).
- put(ip).
- putInt(port).
- position(MessageChunkHeader.HEADER_SIZE).
- flip.asInstanceOf[ByteBuffer]
- }
-
- override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
-}
-private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
- val size = if (buffer == null) 0 else buffer.remaining
- lazy val buffers = {
- val ab = new ArrayBuffer[ByteBuffer]()
- ab += header.buffer
- if (buffer != null) {
- ab += buffer
- }
- ab
- }
+import scala.collection.mutable.ArrayBuffer
- override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
-}
private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
@@ -59,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var finishTime = -1L
def size: Int
-
+
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
-
+
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
-
+
def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
}
-private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
-extends Message(Message.BUFFER_MESSAGE, id_) {
-
- val initialSize = currentSize()
- var gotChunkForSendingOnce = false
-
- def size = initialSize
-
- def currentSize() = {
- if (buffers == null || buffers.isEmpty) {
- 0
- } else {
- buffers.map(_.remaining).reduceLeft(_ + _)
- }
- }
-
- def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
- if (maxChunkSize <= 0) {
- throw new Exception("Max chunk size is " + maxChunkSize)
- }
-
- if (size == 0 && gotChunkForSendingOnce == false) {
- val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
-
- while(!buffers.isEmpty) {
- val buffer = buffers(0)
- if (buffer.remaining == 0) {
- BlockManager.dispose(buffer)
- buffers -= buffer
- } else {
- val newBuffer = if (buffer.remaining <= maxChunkSize) {
- buffer.duplicate()
- } else {
- buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
- }
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
- }
- None
- }
-
- def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
- // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
- if (buffers.size > 1) {
- throw new Exception("Attempting to get chunk from message with multiple data buffers")
- }
- val buffer = buffers(0)
- if (buffer.remaining > 0) {
- if (buffer.remaining < chunkSize) {
- throw new Exception("Not enough space in data buffer for receiving chunk")
- }
- val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- return Some(newChunk)
- }
- None
- }
-
- def flip() {
- buffers.foreach(_.flip)
- }
-
- def hasAckId() = (ackId != 0)
-
- def isCompletelyReceived() = !buffers(0).hasRemaining
-
- override def toString = {
- if (hasAckId) {
- "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
- } else {
- "BufferMessage(id = " + id + ", size = " + size + ")"
- }
- }
-}
-
-private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 40
-
- def create(buffer: ByteBuffer): MessageChunkHeader = {
- if (buffer.remaining != HEADER_SIZE) {
- throw new IllegalArgumentException("Cannot convert buffer data to Message")
- }
- val typ = buffer.getLong()
- val id = buffer.getInt()
- val totalSize = buffer.getInt()
- val chunkSize = buffer.getInt()
- val other = buffer.getInt()
- val ipSize = buffer.getInt()
- val ipBytes = new Array[Byte](ipSize)
- buffer.get(ipBytes)
- val ip = InetAddress.getByAddress(ipBytes)
- val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
- }
-}
private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L
@@ -181,14 +31,16 @@ private[spark] object Message {
def getNewId() = synchronized {
lastId += 1
- if (lastId == 0) lastId += 1
+ if (lastId == 0) {
+ lastId += 1
+ }
lastId
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
if (dataBuffers == null) {
return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
- }
+ }
if (dataBuffers.exists(_ == null)) {
throw new Exception("Attempting to create buffer message with null buffer")
}
@@ -197,7 +49,7 @@ private[spark] object Message {
def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
createBufferMessage(dataBuffers, 0)
-
+
def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
if (dataBuffer == null) {
return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
@@ -205,15 +57,18 @@ private[spark] object Message {
return createBufferMessage(Array(dataBuffer), ackId)
}
}
-
- def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
createBufferMessage(dataBuffer, 0)
-
- def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
+
+ def createBufferMessage(ackId: Int): BufferMessage = {
+ createBufferMessage(new Array[ByteBuffer](0), ackId)
+ }
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
- case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ case BUFFER_MESSAGE => new BufferMessage(header.id,
+ ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.senderAddress = header.address
newMessage
diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala
new file mode 100644
index 0000000000..aaf9204d0e
--- /dev/null
+++ b/core/src/main/scala/spark/network/MessageChunk.scala
@@ -0,0 +1,25 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[network]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+ val size = if (buffer == null) 0 else buffer.remaining
+
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = {
+ "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+ }
+}
diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala
new file mode 100644
index 0000000000..3693d509d6
--- /dev/null
+++ b/core/src/main/scala/spark/network/MessageChunkHeader.scala
@@ -0,0 +1,58 @@
+package spark.network
+
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+
+private[spark] class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+ // No need to change this, at 'use' time, we do a reverse lookup of the hostname.
+ // Refer to network.Connection
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+
+private[spark] object MessageChunkHeader {
+ val HEADER_SIZE = 40
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to Message")
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ }
+}