author     Matei Zaharia <matei@eecs.berkeley.edu>   2013-05-15 18:06:24 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>   2013-05-15 18:06:24 -0700
commit     2f576aba8f6a7a101d2862808d03ec0da5ad00d4 (patch)
tree       2c17071022e4051c803fae945cf868e11a15d454 /core
parent     48c6f46a6f0257bcf5e9d7f9a0e0b5b29d5815f9 (diff)
parent     f3491cb89bcbd8748bedbd2fe6a200ba09d9e3c2 (diff)
Merge pull request #602 from rxin/shufflemerge
Manual merge & cleanup of Shane's Shuffle Performance Optimization
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/java/spark/network/netty/FileClient.java                   |  82
-rw-r--r--  core/src/main/java/spark/network/netty/FileClientChannelInitializer.java |  24
-rw-r--r--  core/src/main/java/spark/network/netty/FileClientHandler.java            |  35
-rw-r--r--  core/src/main/java/spark/network/netty/FileServer.java                   |  51
-rw-r--r--  core/src/main/java/spark/network/netty/FileServerChannelInitializer.java |  25
-rw-r--r--  core/src/main/java/spark/network/netty/FileServerHandler.java            |  65
-rwxr-xr-x  core/src/main/java/spark/network/netty/PathResolver.java                 |  12
-rw-r--r--  core/src/main/scala/spark/network/BufferMessage.scala                    |  94
-rw-r--r--  core/src/main/scala/spark/network/Connection.scala                       | 137
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManager.scala                |  53
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManagerId.scala              |  21
-rw-r--r--  core/src/main/scala/spark/network/Message.scala                          | 179
-rw-r--r--  core/src/main/scala/spark/network/MessageChunk.scala                     |  25
-rw-r--r--  core/src/main/scala/spark/network/MessageChunkHeader.scala               |  58
-rw-r--r--  core/src/main/scala/spark/network/netty/FileHeader.scala                 |  57
-rw-r--r--  core/src/main/scala/spark/network/netty/ShuffleCopier.scala              |  84
-rw-r--r--  core/src/main/scala/spark/network/netty/ShuffleSender.scala              |  56
-rw-r--r--  core/src/main/scala/spark/storage/BlockFetchTracker.scala                |  12
-rw-r--r--  core/src/main/scala/spark/storage/BlockFetcherIterator.scala             | 360
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala                     | 194
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala                        |  52
-rw-r--r--  core/src/test/scala/spark/DistributedSuite.scala                         |   5
-rw-r--r--  core/src/test/scala/spark/ShuffleNettySuite.scala                        |  17
23 files changed, 1258 insertions, 440 deletions
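
For orientation, the new pieces below fit together roughly like this: ShuffleSender wraps the Netty FileServer and serves shuffle block files located through a PathResolver, while ShuffleCopier drives FileClient to pull a block and hand its bytes to a callback. The following sketch is illustrative only and is not part of this commit; the port, directory layout, and callback are assumptions, and since the classes are package-private it would have to live in the spark.network.netty package.

// Sketch only: wiring together the classes added in this diff.
package spark.network.netty

import java.io.File
import io.netty.buffer.ByteBuf
import spark.network.ConnectionManagerId

object NettyShuffleSketch {
  def main(args: Array[String]) {
    // Serve block files from a flat directory (hypothetical layout).
    val resolver = new PathResolver {
      override def getAbsolutePath(blockId: String): String =
        new File("/tmp/shuffle-blocks", blockId).getAbsolutePath
    }
    val sender = new ShuffleSender(9999, resolver)
    new Thread() { override def run() { sender.start() } }.start()
    Thread.sleep(500)  // crude wait for the server to bind; illustration only

    // Fetch one block back and report its size through the result callback.
    val copier = new ShuffleCopier
    copier.getBlock(new ConnectionManagerId("localhost", 9999), "shuffle_0_0_0",
      (blockId: String, size: Long, content: ByteBuf) =>
        println("Fetched " + blockId + ": " + size + " bytes"))
  }
}
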
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java
new file mode 100644
index 0000000000..3a62dacbc8
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClient.java
@@ -0,0 +1,82 @@
+package spark.network.netty;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioSocketChannel;
+
+
+class FileClient {
+
+ private FileClientHandler handler = null;
+ private Channel channel = null;
+ private Bootstrap bootstrap = null;
+
+ public FileClient(FileClientHandler handler) {
+ this.handler = handler;
+ }
+
+ public void init() {
+ bootstrap = new Bootstrap();
+ bootstrap.group(new OioEventLoopGroup())
+ .channel(OioSocketChannel.class)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.TCP_NODELAY, true)
+ .handler(new FileClientChannelInitializer(handler));
+ }
+
+ public static final class ChannelCloseListener implements ChannelFutureListener {
+ private FileClient fc = null;
+
+ public ChannelCloseListener(FileClient fc){
+ this.fc = fc;
+ }
+
+ @Override
+ public void operationComplete(ChannelFuture future) {
+ if (fc.bootstrap!=null){
+ fc.bootstrap.shutdown();
+ fc.bootstrap = null;
+ }
+ }
+ }
+
+ public void connect(String host, int port) {
+ try {
+ // Start the connection attempt.
+ channel = bootstrap.connect(host, port).sync().channel();
+ // ChannelFuture cf = channel.closeFuture();
+ //cf.addListener(new ChannelCloseListener(this));
+ } catch (InterruptedException e) {
+ close();
+ }
+ }
+
+ public void waitForClose() {
+ try {
+ channel.closeFuture().sync();
+ } catch (InterruptedException e){
+ e.printStackTrace();
+ }
+ }
+
+ public void sendRequest(String file) {
+ //assert(file == null);
+ //assert(channel == null);
+ channel.write(file + "\r\n");
+ }
+
+ public void close() {
+ if(channel != null) {
+ channel.close();
+ channel = null;
+ }
+ if ( bootstrap!=null) {
+ bootstrap.shutdown();
+ bootstrap = null;
+ }
+ }
+}
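
FileClient expects a FileClientHandler (defined two files below) to consume the reply. As a hedged illustration that is not part of this commit, a trivial handler that only prints the received header, plus the call order on the client side, could look like the sketch below; the host, port, and block id are assumptions, and the code would have to sit in the spark.network.netty package. ShuffleCopier later in this diff is the real consumer.

// Sketch only: a minimal FileClientHandler and the client call sequence.
package spark.network.netty

import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext

class PrintingClientHandler extends FileClientHandler {
  override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
    println("received block " + header.blockId + " (" + header.fileLen + " bytes)")
  }
}

object FileClientSketch {
  def main(args: Array[String]) {
    val client = new FileClient(new PrintingClientHandler)
    client.init()
    client.connect("localhost", 9999)   // host and port are assumptions
    client.sendRequest("shuffle_0_0_0") // block id is an assumption
    client.waitForClose()               // server closes the channel after sending the file
    client.close()
  }
}
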
diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
new file mode 100644
index 0000000000..af25baf641
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
@@ -0,0 +1,24 @@
+package spark.network.netty;
+
+import io.netty.buffer.BufType;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.string.StringEncoder;
+
+
+class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+ private FileClientHandler fhandler;
+
+ public FileClientChannelInitializer(FileClientHandler handler) {
+ fhandler = handler;
+ }
+
+ @Override
+ public void initChannel(SocketChannel channel) {
+ // file no more than 2G
+ channel.pipeline()
+ .addLast("encoder", new StringEncoder(BufType.BYTE))
+ .addLast("handler", fhandler);
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java
new file mode 100644
index 0000000000..2069dee5ca
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientHandler.java
@@ -0,0 +1,35 @@
+package spark.network.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundByteHandlerAdapter;
+
+
+abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
+
+ private FileHeader currentHeader = null;
+
+ public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
+
+ @Override
+ public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
+ // Use direct buffer if possible.
+ return ctx.alloc().ioBuffer();
+ }
+
+ @Override
+ public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
+ // get header
+ if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
+ currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
+ }
+ // get file
+ if(in.readableBytes() >= currentHeader.fileLen()) {
+ handle(ctx, in, currentHeader);
+ currentHeader = null;
+ ctx.close();
+ }
+ }
+
+}
+
diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java
new file mode 100644
index 0000000000..647b26bf8a
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServer.java
@@ -0,0 +1,51 @@
+package spark.network.netty;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.Channel;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioServerSocketChannel;
+
+
+/**
+ * Server that accepts the path of a file and echoes back its content.
+ */
+class FileServer {
+
+ private ServerBootstrap bootstrap = null;
+ private Channel channel = null;
+ private PathResolver pResolver;
+
+ public FileServer(PathResolver pResolver) {
+ this.pResolver = pResolver;
+ }
+
+ public void run(int port) {
+ // Configure the server.
+ bootstrap = new ServerBootstrap();
+ try {
+ bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
+ .channel(OioServerSocketChannel.class)
+ .option(ChannelOption.SO_BACKLOG, 100)
+ .option(ChannelOption.SO_RCVBUF, 1500)
+ .childHandler(new FileServerChannelInitializer(pResolver));
+ // Start the server.
+ channel = bootstrap.bind(port).sync().channel();
+ channel.closeFuture().sync();
+ } catch (InterruptedException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } finally{
+ bootstrap.shutdown();
+ }
+ }
+
+ public void stop() {
+ if (channel!=null) {
+ channel.close();
+ }
+ if (bootstrap != null) {
+ bootstrap.shutdown();
+ }
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
new file mode 100644
index 0000000000..8f1f5c65cd
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
@@ -0,0 +1,25 @@
+package spark.network.netty;
+
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.DelimiterBasedFrameDecoder;
+import io.netty.handler.codec.Delimiters;
+import io.netty.handler.codec.string.StringDecoder;
+
+
+class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+ PathResolver pResolver;
+
+ public FileServerChannelInitializer(PathResolver pResolver) {
+ this.pResolver = pResolver;
+ }
+
+ @Override
+ public void initChannel(SocketChannel channel) {
+ channel.pipeline()
+ .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
+ .addLast("strDecoder", new StringDecoder())
+ .addLast("handler", new FileServerHandler(pResolver));
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java
new file mode 100644
index 0000000000..a78eddb1b5
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerHandler.java
@@ -0,0 +1,65 @@
+package spark.network.netty;
+
+import java.io.File;
+import java.io.FileInputStream;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundMessageHandlerAdapter;
+import io.netty.channel.DefaultFileRegion;
+
+
+class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
+
+ PathResolver pResolver;
+
+ public FileServerHandler(PathResolver pResolver){
+ this.pResolver = pResolver;
+ }
+
+ @Override
+ public void messageReceived(ChannelHandlerContext ctx, String blockId) {
+ String path = pResolver.getAbsolutePath(blockId);
+ // if getFilePath returns null, close the channel
+ if (path == null) {
+ //ctx.close();
+ return;
+ }
+ File file = new File(path);
+ if (file.exists()) {
+ if (!file.isFile()) {
+ //logger.info("Not a file : " + file.getAbsolutePath());
+ ctx.write(new FileHeader(0, blockId).buffer());
+ ctx.flush();
+ return;
+ }
+ long length = file.length();
+ if (length > Integer.MAX_VALUE || length <= 0) {
+ //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
+ ctx.write(new FileHeader(0, blockId).buffer());
+ ctx.flush();
+ return;
+ }
+ int len = new Long(length).intValue();
+ //logger.info("Sending block "+blockId+" filelen = "+len);
+ //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
+ ctx.write((new FileHeader(len, blockId)).buffer());
+ try {
+ ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
+ .getChannel(), 0, file.length()));
+ } catch (Exception e) {
+ //logger.warning("Exception when sending file : " + file.getAbsolutePath());
+ e.printStackTrace();
+ }
+ } else {
+ //logger.warning("File not found: " + file.getAbsolutePath());
+ ctx.write(new FileHeader(0, blockId).buffer());
+ }
+ ctx.flush();
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ cause.printStackTrace();
+ ctx.close();
+ }
+}
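
Taken together with the pipeline above, the wire format is simple: the client writes the block id followed by \r\n, and the server replies with a 40-byte FileHeader (file length, id length, id, zero padding) followed by the raw file bytes; a zero-length header signals an error. The sketch below reads that reply with a plain socket. It is not part of this commit, and the host, port, and block id are assumptions.

// Sketch only: decoding the server's reply with a plain blocking socket.
import java.io.DataInputStream
import java.net.Socket

object RawProtocolSketch {
  def main(args: Array[String]) {
    val socket = new Socket("localhost", 9999)                 // host and port are assumptions
    socket.getOutputStream.write("shuffle_0_0_0\r\n".getBytes("UTF-8"))
    socket.getOutputStream.flush()

    val in = new DataInputStream(socket.getInputStream)
    val fileLen = in.readInt()                                 // FileHeader: file length
    val idLen = in.readInt()                                   // FileHeader: block id length
    val id = new Array[Byte](idLen)
    in.readFully(id)                                           // FileHeader: block id
    in.readFully(new Array[Byte](40 - 8 - idLen))              // header is padded to 40 bytes
    val body = new Array[Byte](fileLen)
    in.readFully(body)                                         // the file contents
    println("got " + new String(id, "UTF-8") + " (" + fileLen + " bytes)")
    socket.close()
  }
}
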
diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java
new file mode 100755
index 0000000000..302411672c
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/PathResolver.java
@@ -0,0 +1,12 @@
+package spark.network.netty;
+
+
+public interface PathResolver {
+ /**
+ * Get the absolute path of the file
+ *
+ * @param fileId
+ * @return the absolute path of file
+ */
+ public String getAbsolutePath(String fileId);
+}
diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala
new file mode 100644
index 0000000000..7b0e489a6c
--- /dev/null
+++ b/core/src/main/scala/spark/network/BufferMessage.scala
@@ -0,0 +1,94 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark.storage.BlockManager
+
+
+private[spark]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+ extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ if (size == 0 && gotChunkForSendingOnce == false) {
+ val newChunk = new MessageChunk(
+ new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while(!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ BlockManager.dispose(buffer)
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate()
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+ }
+}
\ No newline at end of file
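
As a hedged example of the chunking logic above, and not part of this commit, draining a 100 KB BufferMessage with a 64 KB chunk limit yields one full chunk and one remainder chunk; the sizes are arbitrary, and the sketch assumes it runs inside the spark.network package because the classes are private[spark].

// Sketch only: exercising BufferMessage.getChunkForSending with made-up sizes.
package spark.network

import java.nio.ByteBuffer

object ChunkingSketch {
  def main(args: Array[String]) {
    val payload = ByteBuffer.allocate(100 * 1024)          // 100 KB of zeros
    val msg = Message.createBufferMessage(payload, 0)      // ackId 0 means no ack expected
    var chunk = msg.getChunkForSending(64 * 1024)
    while (chunk.isDefined) {
      println("chunk of " + chunk.get.size + " bytes")     // prints 65536, then 36864
      chunk = msg.getChunkForSending(64 * 1024)
    }
  }
}
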
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index 00a0433a44..6e28f677a3 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -13,12 +13,13 @@ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging {
+ val socketRemoteConnectionManagerId: ConnectionManagerId)
+ extends Logging {
+
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
- ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
- ))
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
}
channel.configureBlocking(false)
@@ -32,17 +33,19 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress()
-
+
// Read channels typically do not register for write and write does not for read
// Now, we do have write registering for read too (temporarily), but this is to detect
// channel close NOT to actually read/consume data on it !
// How does this work if/when we move to SSL ?
-
+
// What is the interest to register with selector for when we want this connection to be selected
def registerInterest()
- // What is the interest to register with selector for when we want this connection to be de-selected
- // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be
- // SelectionKey.OP_READ (until we fix it properly)
+
+ // What is the interest to register with selector for when we want this connection to
+ // be de-selected
+ // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack,
+ // it will be SelectionKey.OP_READ (until we fix it properly)
def unregisterInterest()
// On receiving a read event, should we change the interest for this channel or not ?
@@ -64,12 +67,14 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
// Returns whether we have to register for further reads or not.
def read(): Boolean = {
- throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
+ throw new UnsupportedOperationException(
+ "Cannot read on connection of type " + this.getClass.toString)
}
// Returns whether we have to register for further writes or not.
def write(): Boolean = {
- throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
+ throw new UnsupportedOperationException(
+ "Cannot write on connection of type " + this.getClass.toString)
}
def close() {
@@ -81,11 +86,17 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
callOnCloseCallback()
}
- def onClose(callback: Connection => Unit) {onCloseCallback = callback}
+ def onClose(callback: Connection => Unit) {
+ onCloseCallback = callback
+ }
- def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
+ def onException(callback: (Connection, Exception) => Unit) {
+ onExceptionCallback = callback
+ }
- def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+ onKeyInterestChangeCallback = callback
+ }
def callOnExceptionCallback(e: Exception) {
if (onExceptionCallback != null) {
@@ -95,7 +106,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
" and OnExceptionCallback not registered", e)
}
}
-
+
def callOnCloseCallback() {
if (onCloseCallback != null) {
onCloseCallback(this)
@@ -132,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
print(" (" + position + ", " + length + ")")
buffer.position(curPosition)
}
-
}
-private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
-extends Connection(SocketChannel.open, selector_, remoteId_) {
+private[spark]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId)
+ extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
- val defaultChunkSize = 65536 //32768 //16384
+ val defaultChunkSize = 65536 //32768 //16384
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
- messages.synchronized{
+ messages.synchronized{
/*messages += message*/
messages.enqueue(message)
- logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]")
+ logDebug("Added [" + message + "] to outbox for sending to " +
+ "[" + getRemoteConnectionManagerId() + "]")
}
}
@@ -174,7 +186,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
message.started = true
message.startTime = System.currentTimeMillis
}
- return chunk
+ return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
message.finishTime = System.currentTimeMillis
@@ -185,7 +197,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
None
}
-
+
private def getChunkRR(): Option[MessageChunk] = {
messages.synchronized {
while (!messages.isEmpty) {
@@ -197,12 +209,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.enqueue(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
- logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+ logDebug(
+ "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
- logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
- return chunk
+ logTrace(
+ "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
+ return chunk
} else {
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
@@ -213,7 +227,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
None
}
}
-
+
private val outbox = new Outbox(1)
val currentBuffers = new ArrayBuffer[ByteBuffer]()
@@ -228,11 +242,11 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
// it does - so let us keep it for now.
changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
}
-
+
override def unregisterInterest() {
changeConnectionKeyInterest(DEFAULT_INTEREST)
}
-
+
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
@@ -262,12 +276,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
// selection - though need not necessarily always complete successfully.
val connected = channel.finishConnect
if (!force && !connected) {
- logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+ logInfo(
+ "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
return false
}
// Fallback to previous behavior - assume finishConnect completed
- // This will happen only when finishConnect failed for some repeated number of times (10 or so)
+ // This will happen only when finishConnect failed for some repeated number of times
+ // (10 or so)
// Is highly unlikely unless there was an unclean close of socket, etc
registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
@@ -283,13 +299,13 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
override def write(): Boolean = {
- try{
- while(true) {
+ try {
+ while (true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
outbox.getChunk() match {
case Some(chunk) => {
- currentBuffers ++= chunk.buffers
+ currentBuffers ++= chunk.buffers
}
case None => {
// changeConnectionKeyInterest(0)
@@ -299,7 +315,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
}
-
+
if (currentBuffers.size > 0) {
val buffer = currentBuffers(0)
val remainingBytes = buffer.remaining
@@ -314,7 +330,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
@@ -336,7 +352,8 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
if (length == -1) { // EOF
close()
} else if (length > 0) {
- logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
+ logWarning(
+ "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
}
} catch {
case e: Exception =>
@@ -355,30 +372,32 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
// Must be created within selector loop - else deadlock
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
-extends Connection(channel_, selector_) {
-
+private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+ extends Connection(channel_, selector_) {
+
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
-
+
def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
-
+
def createNewMessage: BufferMessage = {
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
- logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
+ logDebug(
+ "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
newMessage
}
-
+
val message = messages.getOrElseUpdate(header.id, createNewMessage)
- logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
+ logTrace(
+ "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
message.getChunkForReceiving(header.chunkSize)
}
-
+
def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
- messages.get(chunk.header.id)
+ messages.get(chunk.header.id)
}
def removeMessage(message: Message) {
@@ -387,12 +406,14 @@ extends Connection(channel_, selector_) {
}
@volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
override def getRemoteConnectionManagerId(): ConnectionManagerId = {
val currId = inferredRemoteManagerId
if (currId != null) currId else super.getRemoteConnectionManagerId()
}
- // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver.
+ // The receiver's remote address is the local socket on remote side : which is NOT
+ // the connection manager id of the receiver.
// We infer that from the messages we receive on the receiver socket.
private def processConnectionManagerId(header: MessageChunkHeader) {
val currId = inferredRemoteManagerId
@@ -428,7 +449,8 @@ extends Connection(channel_, selector_) {
}
headerBuffer.flip
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
- throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ throw new Exception(
+ "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
}
val header = MessageChunkHeader.create(headerBuffer)
headerBuffer.clear()
@@ -451,9 +473,9 @@ extends Connection(channel_, selector_) {
case _ => throw new Exception("Message of unknown type received")
}
}
-
+
if (currentChunk == null) throw new Exception("No message chunk to receive data")
-
+
val bytesRead = channel.read(currentChunk.buffer)
if (bytesRead == 0) {
// re-register for read event ...
@@ -464,14 +486,15 @@ extends Connection(channel_, selector_) {
}
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
-
+
if (currentChunk.buffer.remaining == 0) {
/*println("Filled buffer at " + System.currentTimeMillis)*/
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.finishTime = System.currentTimeMillis
- logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
+ logDebug("Finished receiving [" + bufferMessage + "] from " +
+ "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
if (onReceiveCallback != null) {
onReceiveCallback(this, bufferMessage)
}
@@ -481,7 +504,7 @@ extends Connection(channel_, selector_) {
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
@@ -491,7 +514,7 @@ extends Connection(channel_, selector_) {
// should not happen - to keep scala compiler happy
return true
}
-
+
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
override def changeInterestForRead(): Boolean = true
@@ -505,7 +528,7 @@ extends Connection(channel_, selector_) {
// it does - so let us keep it for now.
changeConnectionKeyInterest(SelectionKey.OP_READ)
}
-
+
override def unregisterInterest() {
changeConnectionKeyInterest(0)
}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 0eb03630d0..624a094856 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -18,20 +18,7 @@ import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration
import akka.util.duration._
-private[spark] case class ConnectionManagerId(host: String, port: Int) {
- // DEBUG code
- Utils.checkHost(host)
- assert (port > 0)
- def toSocketAddress() = new InetSocketAddress(host, port)
-}
-
-private[spark] object ConnectionManagerId {
- def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
- new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
- }
-}
-
private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus(
@@ -45,7 +32,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def markDone() { completionHandler(this) }
}
-
+
private val selector = SelectorProvider.provider.openSelector()
private val handleMessageExecutor = new ThreadPoolExecutor(
@@ -80,7 +67,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
- serverChannel.socket.setReceiveBufferSize(256 * 1024)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
serverChannel.socket.bind(new InetSocketAddress(port))
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
@@ -351,7 +338,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
case e: Exception => logError("Error in select loop", e)
}
}
-
+
def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
@@ -463,7 +450,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
- logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
@@ -483,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
- case Some(status) => {
- messageStatuses -= bufferMessage.ackId
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
status
}
- case None => {
+ case None => {
throw new Exception("Could not find reference for received ack message " + message.id)
null
}
@@ -507,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logDebug("Not calling back as callback is null")
None
}
-
+
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
@@ -517,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- sendMessage(connectionManagerId, ackMessage.getOrElse {
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
@@ -588,17 +575,17 @@ private[spark] object ConnectionManager {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
-
+
/*testSequentialSending(manager)*/
/*System.gc()*/
/*testParallelSending(manager)*/
/*System.gc()*/
-
+
/*testParallelDecreasingSending(manager)*/
/*System.gc()*/
@@ -610,9 +597,9 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Sequential Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
-
+
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@@ -628,7 +615,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
@@ -643,12 +630,12 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
-
+
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
- println("Started at " + startTime + ", finished at " + finishTime)
+ println("Started at " + startTime + ", finished at " + finishTime)
println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
@@ -658,7 +645,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Decreasing Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
buffers.foreach(_.flip)
@@ -673,7 +660,7 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
-
+
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
@@ -687,7 +674,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Continuous Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala
new file mode 100644
index 0000000000..b554e84251
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManagerId.scala
@@ -0,0 +1,21 @@
+package spark.network
+
+import java.net.InetSocketAddress
+
+import spark.Utils
+
+
+private[spark] case class ConnectionManagerId(host: String, port: Int) {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+
+private[spark] object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+ }
+}
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
index 34fac9e776..d4f03610eb 100644
--- a/core/src/main/scala/spark/network/Message.scala
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -1,56 +1,10 @@
package spark.network
-import spark._
-
-import scala.collection.mutable.ArrayBuffer
-
import java.nio.ByteBuffer
-import java.net.InetAddress
import java.net.InetSocketAddress
-import storage.BlockManager
-
-private[spark] class MessageChunkHeader(
- val typ: Long,
- val id: Int,
- val totalSize: Int,
- val chunkSize: Int,
- val other: Int,
- val address: InetSocketAddress) {
- lazy val buffer = {
- // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection
- val ip = address.getAddress.getAddress()
- val port = address.getPort()
- ByteBuffer.
- allocate(MessageChunkHeader.HEADER_SIZE).
- putLong(typ).
- putInt(id).
- putInt(totalSize).
- putInt(chunkSize).
- putInt(other).
- putInt(ip.size).
- put(ip).
- putInt(port).
- position(MessageChunkHeader.HEADER_SIZE).
- flip.asInstanceOf[ByteBuffer]
- }
-
- override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
-}
-private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
- val size = if (buffer == null) 0 else buffer.remaining
- lazy val buffers = {
- val ab = new ArrayBuffer[ByteBuffer]()
- ab += header.buffer
- if (buffer != null) {
- ab += buffer
- }
- ab
- }
+import scala.collection.mutable.ArrayBuffer
- override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
-}
private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
@@ -59,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var finishTime = -1L
def size: Int
-
+
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
-
+
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
-
+
def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
}
-private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
-extends Message(Message.BUFFER_MESSAGE, id_) {
-
- val initialSize = currentSize()
- var gotChunkForSendingOnce = false
-
- def size = initialSize
-
- def currentSize() = {
- if (buffers == null || buffers.isEmpty) {
- 0
- } else {
- buffers.map(_.remaining).reduceLeft(_ + _)
- }
- }
-
- def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
- if (maxChunkSize <= 0) {
- throw new Exception("Max chunk size is " + maxChunkSize)
- }
-
- if (size == 0 && gotChunkForSendingOnce == false) {
- val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
-
- while(!buffers.isEmpty) {
- val buffer = buffers(0)
- if (buffer.remaining == 0) {
- BlockManager.dispose(buffer)
- buffers -= buffer
- } else {
- val newBuffer = if (buffer.remaining <= maxChunkSize) {
- buffer.duplicate()
- } else {
- buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
- }
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
- }
- None
- }
-
- def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
- // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
- if (buffers.size > 1) {
- throw new Exception("Attempting to get chunk from message with multiple data buffers")
- }
- val buffer = buffers(0)
- if (buffer.remaining > 0) {
- if (buffer.remaining < chunkSize) {
- throw new Exception("Not enough space in data buffer for receiving chunk")
- }
- val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- return Some(newChunk)
- }
- None
- }
-
- def flip() {
- buffers.foreach(_.flip)
- }
-
- def hasAckId() = (ackId != 0)
-
- def isCompletelyReceived() = !buffers(0).hasRemaining
-
- override def toString = {
- if (hasAckId) {
- "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
- } else {
- "BufferMessage(id = " + id + ", size = " + size + ")"
- }
- }
-}
-
-private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 40
-
- def create(buffer: ByteBuffer): MessageChunkHeader = {
- if (buffer.remaining != HEADER_SIZE) {
- throw new IllegalArgumentException("Cannot convert buffer data to Message")
- }
- val typ = buffer.getLong()
- val id = buffer.getInt()
- val totalSize = buffer.getInt()
- val chunkSize = buffer.getInt()
- val other = buffer.getInt()
- val ipSize = buffer.getInt()
- val ipBytes = new Array[Byte](ipSize)
- buffer.get(ipBytes)
- val ip = InetAddress.getByAddress(ipBytes)
- val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
- }
-}
private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L
@@ -181,14 +31,16 @@ private[spark] object Message {
def getNewId() = synchronized {
lastId += 1
- if (lastId == 0) lastId += 1
+ if (lastId == 0) {
+ lastId += 1
+ }
lastId
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
if (dataBuffers == null) {
return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
- }
+ }
if (dataBuffers.exists(_ == null)) {
throw new Exception("Attempting to create buffer message with null buffer")
}
@@ -197,7 +49,7 @@ private[spark] object Message {
def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
createBufferMessage(dataBuffers, 0)
-
+
def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
if (dataBuffer == null) {
return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
@@ -205,15 +57,18 @@ private[spark] object Message {
return createBufferMessage(Array(dataBuffer), ackId)
}
}
-
- def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
createBufferMessage(dataBuffer, 0)
-
- def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
+
+ def createBufferMessage(ackId: Int): BufferMessage = {
+ createBufferMessage(new Array[ByteBuffer](0), ackId)
+ }
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
- case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ case BUFFER_MESSAGE => new BufferMessage(header.id,
+ ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.senderAddress = header.address
newMessage
diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala
new file mode 100644
index 0000000000..aaf9204d0e
--- /dev/null
+++ b/core/src/main/scala/spark/network/MessageChunk.scala
@@ -0,0 +1,25 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[network]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+ val size = if (buffer == null) 0 else buffer.remaining
+
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = {
+ "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+ }
+}
diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala
new file mode 100644
index 0000000000..3693d509d6
--- /dev/null
+++ b/core/src/main/scala/spark/network/MessageChunkHeader.scala
@@ -0,0 +1,58 @@
+package spark.network
+
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+
+private[spark] class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+ // No need to change this, at 'use' time, we do a reverse lookup of the hostname.
+ // Refer to network.Connection
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+
+private[spark] object MessageChunkHeader {
+ val HEADER_SIZE = 40
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to Message")
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ }
+}
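
As a hedged round trip of the 40-byte encoding above, and not part of this commit, building a header and parsing its buffer back with MessageChunkHeader.create should reproduce the same fields; the address, ids, and sizes are made up, and the sketch assumes the spark.network package since the class is private[spark].

// Sketch only: encode a MessageChunkHeader and decode it again.
package spark.network

import java.net.InetSocketAddress

object HeaderRoundTrip {
  def main(args: Array[String]) {
    val addr = new InetSocketAddress("localhost", 9999)
    val header = new MessageChunkHeader(Message.BUFFER_MESSAGE, 42, 1024, 512, 0, addr)
    val decoded = MessageChunkHeader.create(header.buffer)
    // Expected output: MessageChunkHeader:42 of type 1111111111 and sizes 1024 / 512 bytes
    println(decoded)
  }
}
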
diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala
new file mode 100644
index 0000000000..aed4254234
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/FileHeader.scala
@@ -0,0 +1,57 @@
+package spark.network.netty
+
+import io.netty.buffer._
+
+import spark.Logging
+
+private[spark] class FileHeader (
+ val fileLen: Int,
+ val blockId: String) extends Logging {
+
+ lazy val buffer = {
+ val buf = Unpooled.buffer()
+ buf.capacity(FileHeader.HEADER_SIZE)
+ buf.writeInt(fileLen)
+ buf.writeInt(blockId.length)
+ blockId.foreach((x: Char) => buf.writeByte(x))
+ //padding the rest of header
+ if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
+ buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
+ } else {
+ throw new Exception("too long header " + buf.readableBytes)
+ logInfo("too long header")
+ }
+ buf
+ }
+
+}
+
+private[spark] object FileHeader {
+
+ val HEADER_SIZE = 40
+
+ def getFileLenOffset = 0
+ def getFileLenSize = Integer.SIZE/8
+
+ def create(buf: ByteBuf): FileHeader = {
+ val length = buf.readInt
+ val idLength = buf.readInt
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buf.readByte().asInstanceOf[Char]
+ }
+ val blockId = idBuilder.toString()
+ new FileHeader(length, blockId)
+ }
+
+
+ def main (args:Array[String]){
+
+ val header = new FileHeader(25,"block_0");
+ val buf = header.buffer;
+ val newheader = FileHeader.create(buf);
+ System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
+
+ }
+}
+
diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
new file mode 100644
index 0000000000..a91f5a886d
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
@@ -0,0 +1,84 @@
+package spark.network.netty
+
+import java.util.concurrent.Executors
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.ChannelHandlerContext
+import io.netty.util.CharsetUtil
+
+import spark.Logging
+import spark.network.ConnectionManagerId
+
+
+private[spark] class ShuffleCopier extends Logging {
+
+ def getBlock(cmId: ConnectionManagerId, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
+ val fc = new FileClient(handler)
+ fc.init()
+ fc.connect(cmId.host, cmId.port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ }
+
+ def getBlocks(cmId: ConnectionManagerId,
+ blocks: Seq[(String, Long)],
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ for ((blockId, size) <- blocks) {
+ getBlock(cmId, blockId, resultCollectCallback)
+ }
+ }
+}
+
+
+private[spark] object ShuffleCopier extends Logging {
+
+ private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ extends FileClientHandler with Logging {
+
+ override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
+ }
+ }
+
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
+
+ def runGetBlock(host:String, port:Int, file:String){
+ val handler = new ShuffleClientHandler(echoResultCollectCallBack)
+ val fc = new FileClient(handler)
+ fc.init();
+ fc.connect(host, port)
+ fc.sendRequest(file)
+ fc.waitForClose();
+ fc.close()
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
+ System.exit(1)
+ }
+ val host = args(0)
+ val port = args(1).toInt
+ val file = args(2)
+ val threads = if (args.length > 3) args(3).toInt else 10
+
+ val copiers = Executors.newFixedThreadPool(80)
+ for (i <- Range(0, threads)) {
+ val runnable = new Runnable() {
+ def run() {
+ runGetBlock(host, port, file)
+ }
+ }
+ copiers.execute(runnable)
+ }
+ copiers.shutdown
+ }
+}
diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
new file mode 100644
index 0000000000..dc87fefc56
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
@@ -0,0 +1,56 @@
+package spark.network.netty
+
+import java.io.File
+
+import spark.Logging
+
+
+private[spark] class ShuffleSender(val port: Int, val pResolver: PathResolver) extends Logging {
+ val server = new FileServer(pResolver)
+
+ Runtime.getRuntime().addShutdownHook(
+ new Thread() {
+ override def run() {
+ server.stop()
+ }
+ }
+ )
+
+ def start() {
+ server.run(port)
+ }
+}
+
+
+private[spark] object ShuffleSender {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
+ System.exit(1)
+ }
+
+ val port = args(0).toInt
+ val subDirsPerLocalDir = args(1).toInt
+ val localDirs = args.drop(2).map(new File(_))
+
+ val pResovler = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ throw new Exception("Block " + blockId + " is not a shuffle block")
+ }
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = math.abs(blockId.hashCode)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+ val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ val file = new File(subDir, blockId)
+ return file.getAbsolutePath
+ }
+ }
+ val sender = new ShuffleSender(port, pResovler)
+
+ sender.start()
+ }
+}
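
To make the directory hashing in the resolver above concrete, the sketch below applies the same dirId/subDirId arithmetic to made-up directories and a made-up block id; it is illustrative only and not part of this commit. With 2 local dirs and 64 subdirectories per dir, every block id maps to one of 128 subdirectories.

// Sketch only: the hash-to-subdirectory arithmetic used by the resolver above.
object HashLayoutSketch {
  def main(args: Array[String]) {
    val localDirs = Array("/tmp/spark-local-0", "/tmp/spark-local-1")   // assumptions
    val subDirsPerLocalDir = 64
    val blockId = "shuffle_0_1_2"
    val hash = math.abs(blockId.hashCode)
    val dirId = hash % localDirs.length
    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
    // e.g. /tmp/spark-local-?/<two hex digits>/shuffle_0_1_2
    println(localDirs(dirId) + "/" + "%02x".format(subDirId) + "/" + blockId)
  }
}
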
diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
index 993aece1f7..0718156b1b 100644
--- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala
+++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
@@ -1,10 +1,10 @@
package spark.storage
private[spark] trait BlockFetchTracker {
- def totalBlocks : Int
- def numLocalBlocks: Int
- def numRemoteBlocks: Int
- def remoteFetchTime : Long
- def fetchWaitTime: Long
- def remoteBytesRead : Long
+ def totalBlocks : Int
+ def numLocalBlocks: Int
+ def numRemoteBlocks: Int
+ def remoteFetchTime : Long
+ def fetchWaitTime: Long
+ def remoteBytesRead : Long
}
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
new file mode 100644
index 0000000000..43f835237c
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -0,0 +1,360 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Queue
+
+import io.netty.buffer.ByteBuf
+
+import spark.Logging
+import spark.Utils
+import spark.SparkException
+import spark.network.BufferMessage
+import spark.network.ConnectionManagerId
+import spark.network.netty.ShuffleCopier
+import spark.serializer.Serializer
+
+
+/**
+ * A block fetcher iterator interface. There are two implementations:
+ *
+ * BasicBlockFetcherIterator: uses a custom-built NIO communication layer.
+ * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer.
+ *
+ * Eventually we would like the two to converge and use a single NIO-based communication layer,
+ * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores),
+ * NIO would perform poorly and thus the need for the Netty OIO one.
+ */
+
+private[storage]
+trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+ with Logging with BlockFetchTracker {
+ def initialize()
+}
+
+
+private[storage]
+object BlockFetcherIterator {
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+ // the block (since we want all deserialization to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+ class BasicBlockFetcherIterator(
+ private val blockManager: BlockManager,
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer)
+ extends BlockFetcherIterator {
+
+ import blockManager._
+
+ private var _remoteBytesRead = 0l
+ private var _remoteFetchTime = 0l
+ private var _fetchWaitTime = 0l
+
+ if (blocksByAddress == null) {
+ throw new IllegalArgumentException("BlocksByAddress is null")
+ }
+
+ protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
+ logDebug("Getting " + _totalBlocks + " blocks")
+ protected var startTime = System.currentTimeMillis
+ protected val localBlockIds = new ArrayBuffer[String]()
+ protected val remoteBlockIds = new HashSet[String]()
+
+ // A queue to hold our results.
+ protected val results = new LinkedBlockingQueue[FetchResult]
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ private val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ private var bytesInFlight = 0L
+
+ protected def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map {
+ case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
+ val fetchStart = System.currentTimeMillis()
+ val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+ future.onSuccess {
+ case Some(message) => {
+ val fetchDone = System.currentTimeMillis()
+ _remoteFetchTime += fetchDone - fetchStart
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+ for (blockMessage <- blockMessageArray) {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ throw new SparkException(
+ "Unexpected message " + blockMessage.getType + " received from " + cmId)
+ }
+ val blockId = blockMessage.getId
+ results.put(new FetchResult(blockId, sizeMap(blockId),
+ () => dataDeserialize(blockId, blockMessage.getData, serializer)))
+ _remoteBytesRead += req.size
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
+ }
+ case None => {
+ logError("Could not get block(s) from " + cmId)
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ }
+
+ protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockInfos.map(_._1)
+ } else {
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ remoteRequests
+ }
+
+ protected def getLocalBlocks() {
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ for (id <- localBlockIds) {
+ getLocal(id) match {
+ case Some(iter) => {
+ // Pass 0 as size since it's not in flight
+ results.put(new FetchResult(id, 0, () => iter))
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ }
+ }
+
+ override def initialize() {
+ // Split local and remote blocks.
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numGets = remoteBlockIds.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+      // Get local blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+    // An iterator that will read fetched blocks off the queue as they arrive.
+ @volatile protected var resultsGotten = 0
+
+ override def hasNext: Boolean = resultsGotten < _totalBlocks
+
+ override def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val startFetchWait = System.currentTimeMillis()
+ val result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ _fetchWaitTime += (stopFetchWait - startFetchWait)
+ if (! result.failed) bytesInFlight -= result.size
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+
+ // Implementing BlockFetchTracker trait.
+ override def totalBlocks: Int = _totalBlocks
+ override def numLocalBlocks: Int = localBlockIds.size
+ override def numRemoteBlocks: Int = remoteBlockIds.size
+ override def remoteFetchTime: Long = _remoteFetchTime
+ override def fetchWaitTime: Long = _fetchWaitTime
+ override def remoteBytesRead: Long = _remoteBytesRead
+ }
+ // End of BasicBlockFetcherIterator
+
+ class NettyBlockFetcherIterator(
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer)
+ extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
+
+ import blockManager._
+
+ val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
+
+    private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
+      (for (i <- Range(0, numCopiers)) yield {
+        val copier = new Thread {
+          override def run() {
+            try {
+              while (!isInterrupted && !fetchRequestsSync.isEmpty) {
+                sendRequest(fetchRequestsSync.take())
+              }
+            } catch {
+              case x: InterruptedException => logInfo("Copier Interrupted")
+              //case _ => throw new SparkException("Exception thrown in Shuffle Copier")
+            }
+          }
+        }
+        copier.start()
+        copier
+      }).toList
+    }
+
+    // Keep this so we can interrupt the copier threads when necessary
+ private def stopCopiers() {
+ for (copier <- copiers) {
+ copier.interrupt()
+ }
+ }
+
+ override protected def sendRequest(req: FetchRequest) {
+
+ def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+ val fetchResult = new FetchResult(blockId, blockSize,
+ () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
+ results.put(fetchResult)
+ }
+
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host))
+ val cmId = new ConnectionManagerId(
+ req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
+ val cpier = new ShuffleCopier
+ cpier.getBlocks(cmId, req.blocks, putResult)
+ logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
+ }
+
+ override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+      val originalTotalBlocks = _totalBlocks
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockInfos.map(_._1)
+ } else {
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ if (size > 0) {
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ } else if (size == 0) {
+              // Zero-byte blocks are skipped, so adjust _totalBlocks accordingly
+ _totalBlocks -= 1
+ } else {
+ throw new BlockException(blockId, "Negative block size " + size)
+ }
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+      logInfo("Getting " + _totalBlocks + " non-zero-byte blocks out of " +
+ originalTotalBlocks + " blocks")
+ remoteRequests
+ }
+
+ private var copiers: List[_ <: Thread] = null
+
+ override def initialize() {
+      // Split local and remote blocks; adjust _totalBlocks to count only non-zero-byte blocks
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ for (request <- Utils.randomize(remoteRequests)) {
+ fetchRequestsSync.put(request)
+ }
+
+ copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
+ logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
+ Utils.getUsedTimeMs(startTime))
+
+      // Get local blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ override def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val result = results.take()
+      // If all the results have been retrieved, shut down the copier threads
+ if (resultsGotten == _totalBlocks && copiers != null) {
+ stopCopiers()
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+ }
+ // End of NettyBlockFetcherIterator
+}
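The request-splitting policy used by both fetcher variants above (group blocks into requests of at least maxBytesInFlight / 5 bytes, so roughly five fetches can run in parallel instead of blocking on one node) is easiest to see in isolation. The following standalone Scala sketch mirrors the grouping loop in splitLocalRemoteBlocks; the names FetchRequestSplitSketch and split are illustrative only and are not part of this patch.

    object FetchRequestSplitSketch {

      // Group (blockId, size) pairs into requests of at least maxBytesInFlight / 5 bytes.
      def split(blocks: Seq[(String, Long)], maxBytesInFlight: Long): Seq[Seq[(String, Long)]] = {
        import scala.collection.mutable.ArrayBuffer
        val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
        val requests = new ArrayBuffer[Seq[(String, Long)]]
        var curBlocks = new ArrayBuffer[(String, Long)]
        var curRequestSize = 0L
        for ((blockId, size) <- blocks) {
          curBlocks += ((blockId, size))
          curRequestSize += size
          if (curRequestSize >= minRequestSize) {
            requests += curBlocks.toSeq
            curBlocks = new ArrayBuffer[(String, Long)]
            curRequestSize = 0L
          }
        }
        if (!curBlocks.isEmpty) {
          requests += curBlocks.toSeq
        }
        requests.toSeq
      }

      def main(args: Array[String]) {
        // With maxBytesInFlight = 50, minRequestSize = 10: the first two blocks form one
        // request (4 + 8 = 12 >= 10) and the third block forms its own request.
        val blocks = Seq(("shuffle_0_0_0", 4L), ("shuffle_0_1_0", 8L), ("shuffle_0_2_0", 12L))
        split(blocks, 50L).foreach(req => println(req.map(_._1).mkString(", ")))
      }
    }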
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 72962a3ceb..40d608628e 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -2,7 +2,6 @@ package spark.storage
import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet, Queue}
import scala.collection.JavaConversions._
@@ -24,8 +23,7 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
import sun.nio.ch.DirectBuffer
-private[spark]
-class BlockManager(
+private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
@@ -484,7 +482,16 @@ class BlockManager(
def getMultiple(
blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
- return new BlockFetcherIterator(this, blocksByAddress, serializer)
+
+ val iter =
+ if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+ new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+ } else {
+ new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+ }
+
+ iter.initialize()
+ iter
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -928,8 +935,8 @@ class BlockManager(
}
}
-private[spark]
-object BlockManager extends Logging {
+
+private[spark] object BlockManager extends Logging {
val ID_GENERATOR = new IdGenerator
@@ -962,7 +969,7 @@ object BlockManager extends Logging {
def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = {
// env == null and blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
- val locationBlockIds: Seq[Seq[BlockManagerId]] =
+ val locationBlockIds: Seq[Seq[BlockManagerId]] =
if (env != null) {
val blockManager = env.blockManager
blockManager.getLocationBlockIds(blockIds)
@@ -1000,176 +1007,3 @@ object BlockManager extends Logging {
}
-class BlockFetcherIterator(
- private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
- serializer: Serializer
-) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
-
- import blockManager._
-
- private var _remoteBytesRead = 0l
- private var _remoteFetchTime = 0l
- private var _fetchWaitTime = 0l
-
- if (blocksByAddress == null) {
- throw new IllegalArgumentException("BlocksByAddress is null")
- }
- val totalBlocks = blocksByAddress.map(_._2.size).sum
- logDebug("Getting " + totalBlocks + " blocks")
- var startTime = System.currentTimeMillis
- val localBlockIds = new ArrayBuffer[String]()
- val remoteBlockIds = new HashSet[String]()
-
- // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
- // the block (since we want all deserializaton to happen in the calling thread); can also
- // represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
- }
-
- // A queue to hold our results.
- val results = new LinkedBlockingQueue[FetchResult]
-
- // A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
- val size = blocks.map(_._2).sum
- }
-
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
- val fetchRequests = new Queue[FetchRequest]
-
- // Current bytes in flight from our requests
- var bytesInFlight = 0L
-
- def sendRequest(req: FetchRequest) {
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
- val cmId = new ConnectionManagerId(req.address.host, req.address.port)
- val blockMessageArray = new BlockMessageArray(req.blocks.map {
- case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
- })
- bytesInFlight += req.size
- val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
- val fetchStart = System.currentTimeMillis()
- val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- future.onSuccess {
- case Some(message) => {
- val fetchDone = System.currentTimeMillis()
- _remoteFetchTime += fetchDone - fetchStart
- val bufferMessage = message.asInstanceOf[BufferMessage]
- val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- throw new SparkException(
- "Unexpected message " + blockMessage.getType + " received from " + cmId)
- }
- val blockId = blockMessage.getId
- results.put(new FetchResult(blockId, sizeMap(blockId),
- () => dataDeserialize(blockId, blockMessage.getData, serializer)))
- _remoteBytesRead += req.size
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
- }
- }
- case None => {
- logError("Could not get block(s) from " + cmId)
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
- }
- }
- }
-
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val remoteRequests = new ArrayBuffer[FetchRequest]
- for ((address, blockInfos) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
- } else {
- remoteBlockIds ++= blockInfos.map(_._1)
- // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- curBlocks += ((blockId, size))
- curRequestSize += size
- if (curRequestSize >= minRequestSize) {
- // Add this FetchRequest
- remoteRequests += new FetchRequest(address, curBlocks)
- curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
- }
- }
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
- }
- }
- // Add the remote requests into our queue in a random order
- fetchRequests ++= Utils.randomize(remoteRequests)
-
- // Send out initial requests for blocks, up to our maxBytesInFlight
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
-
- val numGets = remoteBlockIds.size - fetchRequests.size
- logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- startTime = System.currentTimeMillis
- for (id <- localBlockIds) {
- getLocalFromDisk(id, serializer) match {
- case Some(iter) => {
- results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
- logDebug("Got local block " + id)
- }
- case None => {
- throw new BlockException(id, "Could not get block " + id + " from local machine")
- }
- }
- }
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
- //an iterator that will read fetched blocks off the queue as they arrive.
- @volatile private var resultsGotten = 0
-
- def hasNext: Boolean = resultsGotten < totalBlocks
-
- def next(): (String, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val startFetchWait = System.currentTimeMillis()
- val result = results.take()
- val stopFetchWait = System.currentTimeMillis()
- _fetchWaitTime += (stopFetchWait - startFetchWait)
- if (! result.failed) bytesInFlight -= result.size
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
- }
-
-
- //methods to profile the block fetching
- def numLocalBlocks = localBlockIds.size
- def numRemoteBlocks = remoteBlockIds.size
-
- def remoteFetchTime = _remoteFetchTime
- def fetchWaitTime = _fetchWaitTime
-
- def remoteBytesRead = _remoteBytesRead
-
-}
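For context, a caller-side sketch of how the iterator returned by the new getMultiple might be consumed, assuming the returned BlockFetcherIterator is iterated like any Scala Iterator. getMultiple now calls initialize() itself, so the caller only iterates; the values blockManager, blocksByAddress and serializer are placeholders, and the error handling shown is illustrative rather than code from this patch.

    // Hypothetical consumer; deserialization runs lazily in this (calling) thread.
    for ((blockId, maybeValues) <- blockManager.getMultiple(blocksByAddress, serializer)) {
      maybeValues match {
        case Some(values) => values.foreach(v => println(blockId + " -> " + v))
        case None         => sys.error("Failed to fetch block " + blockId)
      }
    }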
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 8154b8ca74..933eeaa216 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -14,13 +14,16 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark.Utils
import spark.executor.ExecutorExitCode
import spark.serializer.{Serializer, SerializationStream}
+import spark.Logging
+import spark.network.netty.ShuffleSender
+import spark.network.netty.PathResolver
/**
* Stores BlockManager blocks on disk.
*/
private class DiskStore(blockManager: BlockManager, rootDirs: String)
- extends BlockStore(blockManager) {
+ extends BlockStore(blockManager) with Logging {
class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
extends BlockObjectWriter(blockId) {
@@ -79,19 +82,28 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+ var shuffleSender : Thread = null
+ val thisInstance = this
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
val localDirs = createLocalDirs()
val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
+
addShutdownHook()
+  if (useNetty) {
+ startShuffleBlockSender()
+ }
+
def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
+
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
@@ -262,10 +274,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
localDirs.foreach { localDir =>
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
}
+ if (useNetty && shuffleSender != null)
+ shuffleSender.stop
} catch {
case t: Throwable => logError("Exception while deleting local spark dirs", t)
}
}
})
}
+
+ private def startShuffleBlockSender() {
+ try {
+ val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt
+
+ val pResolver = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ return null
+ }
+ thisInstance.getFile(blockId).getAbsolutePath()
+ }
+ }
+ shuffleSender = new Thread {
+ override def run() = {
+ val sender = new ShuffleSender(port, pResolver)
+          logInfo("Created ShuffleSender binding to port: " + port)
+ sender.start
+ }
+ }
+ shuffleSender.setDaemon(true)
+ shuffleSender.start
+
+ } catch {
+ case interrupted: InterruptedException =>
+ logInfo("Runner thread for ShuffleBlockSender interrupted")
+
+ case e: Exception => {
+ logError("Error running ShuffleBlockSender ", e)
+ if (shuffleSender != null) {
+ shuffleSender.stop
+ shuffleSender = null
+ }
+ }
+ }
+ }
}
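Taken together, the DiskStore and BlockManager changes are driven by a handful of system properties. The sketch below shows one way they might be set before constructing a SparkContext; the values are simply the defaults visible in this diff, the application name is arbitrary, and only spark.shuffle.use.netty is required to opt in.

    // Opt in to the Netty shuffle path (off by default).
    System.setProperty("spark.shuffle.use.netty", "true")
    // Port the ShuffleSender started by DiskStore binds to (default 6653).
    System.setProperty("spark.shuffle.sender.port", "6653")
    // Number of copier threads used by NettyBlockFetcherIterator (default 6).
    System.setProperty("spark.shuffle.copier.threads", "6")

    val sc = new spark.SparkContext("local-cluster[2,1,512]", "netty-shuffle-demo")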
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index 33c99471c6..06a94ed24c 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -18,10 +18,13 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
import storage.{GetBlock, BlockManagerWorker, StorageLevel}
+
class NotSerializableClass
class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
+
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+ with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
diff --git a/core/src/test/scala/spark/ShuffleNettySuite.scala b/core/src/test/scala/spark/ShuffleNettySuite.scala
new file mode 100644
index 0000000000..bfaffa953e
--- /dev/null
+++ b/core/src/test/scala/spark/ShuffleNettySuite.scala
@@ -0,0 +1,17 @@
+package spark
+
+import org.scalatest.BeforeAndAfterAll
+
+
+class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
+
+ // This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
+
+ override def beforeAll(configMap: Map[String, Any]) {
+ System.setProperty("spark.shuffle.use.netty", "true")
+ }
+
+ override def afterAll(configMap: Map[String, Any]) {
+ System.setProperty("spark.shuffle.use.netty", "false")
+ }
+}
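Beyond the suite above, the same toggle can be exercised by hand with a small shuffle job. This is an illustrative snippet, not part of the patch; it assumes a build of this branch is on the classpath and that the local-cluster master URL from DistributedSuite is usable in your environment.

    import spark.SparkContext
    import spark.SparkContext._

    object NettyShuffleSmokeTest {
      def main(args: Array[String]) {
        System.setProperty("spark.shuffle.use.netty", "true")
        val sc = new SparkContext("local-cluster[2,1,512]", "netty-shuffle-smoke-test")
        // reduceByKey forces a shuffle, which with the property above goes through
        // ShuffleSender on the map side and ShuffleCopier on the reduce side.
        val counts = sc.parallelize(1 to 1000, 10).map(i => (i % 10, 1)).reduceByKey(_ + _).collect()
        counts.sortBy(_._1).foreach(println)
        sc.stop()
      }
    }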