aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorshane-huang <shengsheng.huang@intel.com>2013-01-08 22:40:58 +0800
committershane-huang <shengsheng.huang@intel.com>2013-01-08 22:40:58 +0800
commite4cb72da8a5428c6b9097e92ddbdf4ceee087b85 (patch)
treea0d60175b69625c0789d8e4743d8d7443460ab27
parenta37adfa67bac51b2630c6e1673f8607a87273402 (diff)
downloadspark-e4cb72da8a5428c6b9097e92ddbdf4ceee087b85.tar.gz
spark-e4cb72da8a5428c6b9097e92ddbdf4ceee087b85.tar.bz2
spark-e4cb72da8a5428c6b9097e92ddbdf4ceee087b85.zip
Fix an issue in ConnectionManager where sendingMessage may create too many unnecessary SendingConnections.
-rw-r--r--core/src/main/scala/spark/network/Connection.scala7
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala17
-rw-r--r--core/src/main/scala/spark/network/ConnectionManagerTest.scala18
3 files changed, 23 insertions, 19 deletions
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index 80262ab7b4..95096fd0ba 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) {
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages += message // this is probably incorrect, it wont work as fifo
- if (!message.started) logDebug("Starting to send [" + message + "]")
- message.started = true
+ if (!message.started) {
+ logDebug("Starting to send [" + message + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 642fa4b525..e7bd2d3bbd 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
val selector = SelectorProvider.provider.openSelector()
- val handleMessageExecutor = Executors.newFixedThreadPool(4)
+ val handleMessageExecutor = Executors.newFixedThreadPool(20)
val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus]
- val connectionRequests = new SynchronizedQueue[SendingConnection]
+ val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
@@ -78,11 +78,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def run() {
try {
- while(!selectorThread.isInterrupted) {
- while(!connectionRequests.isEmpty) {
- val sendingConnection = connectionRequests.dequeue
+ while(!selectorThread.isInterrupted) {
+ for( (connectionManagerId, sendingConnection) <- connectionRequests) {
+ //val sendingConnection = connectionRequests.dequeue
sendingConnection.connect()
addConnection(sendingConnection)
+ connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
while(!sendMessageRequests.isEmpty) {
@@ -300,8 +301,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = new SendingConnection(inetSocketAddress, selector)
- connectionRequests += newConnection
+ val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector))
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
@@ -465,7 +465,7 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
- val g = Await.result(f, 1 second)
+ val g = Await.result(f, 10 second)
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
@@ -473,6 +473,7 @@ private[spark] object ConnectionManager {
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
}
diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
index 47ceaf3c07..0e79c518e0 100644
--- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
@@ -13,8 +13,8 @@ import akka.util.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
- if (args.length < 2) {
- println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
+ if (args.length < 5) {
+ println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> <num of tasks> <size of msg> <count>")
System.exit(1)
}
@@ -29,16 +29,16 @@ private[spark] object ConnectionManagerTest extends Logging{
/*println("Slaves")*/
/*slaves.foreach(println)*/
-
- val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
+ val tasknum = args(2).toInt
+ val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
i => SparkEnv.get.connectionManager.id).collect()
println("\nSlave ConnectionManagerIds")
slaveConnManagerIds.foreach(println)
println
- val count = 10
+ val count = args(4).toInt
(0 until count).foreach(i => {
- val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
+ val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
val connManager = SparkEnv.get.connectionManager
val thisConnManagerId = connManager.id
connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
@@ -46,7 +46,7 @@ private[spark] object ConnectionManagerTest extends Logging{
None
})
- val size = 100 * 1024 * 1024
+ val size = (args(3).toInt) * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@@ -56,13 +56,13 @@ private[spark] object ConnectionManagerTest extends Logging{
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
})
- val results = futures.map(f => Await.result(f, 1.second))
+ val results = futures.map(f => Await.result(f, 999.second))
val finishTime = System.currentTimeMillis
Thread.sleep(5000)
val mb = size * results.size / 1024.0 / 1024.0
val ms = finishTime - startTime
- val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
+ val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
logInfo(resultStr)
resultStr
}).collect()