Diffstat (limited to 'core')
23 files changed, 1258 insertions, 440 deletions
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java new file mode 100644 index 0000000000..3a62dacbc8 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -0,0 +1,82 @@ +package spark.network.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioSocketChannel; + + +class FileClient { + + private FileClientHandler handler = null; + private Channel channel = null; + private Bootstrap bootstrap = null; + + public FileClient(FileClientHandler handler) { + this.handler = handler; + } + + public void init() { + bootstrap = new Bootstrap(); + bootstrap.group(new OioEventLoopGroup()) + .channel(OioSocketChannel.class) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.TCP_NODELAY, true) + .handler(new FileClientChannelInitializer(handler)); + } + + public static final class ChannelCloseListener implements ChannelFutureListener { + private FileClient fc = null; + + public ChannelCloseListener(FileClient fc){ + this.fc = fc; + } + + @Override + public void operationComplete(ChannelFuture future) { + if (fc.bootstrap!=null){ + fc.bootstrap.shutdown(); + fc.bootstrap = null; + } + } + } + + public void connect(String host, int port) { + try { + // Start the connection attempt. + channel = bootstrap.connect(host, port).sync().channel(); + // ChannelFuture cf = channel.closeFuture(); + //cf.addListener(new ChannelCloseListener(this)); + } catch (InterruptedException e) { + close(); + } + } + + public void waitForClose() { + try { + channel.closeFuture().sync(); + } catch (InterruptedException e){ + e.printStackTrace(); + } + } + + public void sendRequest(String file) { + //assert(file == null); + //assert(channel == null); + channel.write(file + "\r\n"); + } + + public void close() { + if(channel != null) { + channel.close(); + channel = null; + } + if ( bootstrap!=null) { + bootstrap.shutdown(); + bootstrap = null; + } + } +} diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java new file mode 100644 index 0000000000..af25baf641 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java @@ -0,0 +1,24 @@ +package spark.network.netty; + +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.string.StringEncoder; + + +class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> { + + private FileClientHandler fhandler; + + public FileClientChannelInitializer(FileClientHandler handler) { + fhandler = handler; + } + + @Override + public void initChannel(SocketChannel channel) { + // file no more than 2G + channel.pipeline() + .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("handler", fhandler); + } +} diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java new file mode 100644 index 0000000000..2069dee5ca --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -0,0 +1,35 @@ +package spark.network.netty; + +import io.netty.buffer.ByteBuf; +import 
io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundByteHandlerAdapter; + + +abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { + + private FileHeader currentHeader = null; + + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + + @Override + public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { + // Use direct buffer if possible. + return ctx.alloc().ioBuffer(); + } + + @Override + public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + // get header + if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + } + // get file + if(in.readableBytes() >= currentHeader.fileLen()) { + handle(ctx, in, currentHeader); + currentHeader = null; + ctx.close(); + } + } + +} + diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java new file mode 100644 index 0000000000..647b26bf8a --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -0,0 +1,51 @@ +package spark.network.netty; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelOption; +import io.netty.channel.Channel; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioServerSocketChannel; + + +/** + * Server that accept the path of a file an echo back its content. + */ +class FileServer { + + private ServerBootstrap bootstrap = null; + private Channel channel = null; + private PathResolver pResolver; + + public FileServer(PathResolver pResolver) { + this.pResolver = pResolver; + } + + public void run(int port) { + // Configure the server. + bootstrap = new ServerBootstrap(); + try { + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + .channel(OioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .option(ChannelOption.SO_RCVBUF, 1500) + .childHandler(new FileServerChannelInitializer(pResolver)); + // Start the server. 
+ channel = bootstrap.bind(port).sync().channel(); + channel.closeFuture().sync(); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } finally{ + bootstrap.shutdown(); + } + } + + public void stop() { + if (channel!=null) { + channel.close(); + } + if (bootstrap != null) { + bootstrap.shutdown(); + } + } +} diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java new file mode 100644 index 0000000000..8f1f5c65cd --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java @@ -0,0 +1,25 @@ +package spark.network.netty; + +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.handler.codec.string.StringDecoder; + + +class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> { + + PathResolver pResolver; + + public FileServerChannelInitializer(PathResolver pResolver) { + this.pResolver = pResolver; + } + + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline() + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) + .addLast("strDecoder", new StringDecoder()) + .addLast("handler", new FileServerHandler(pResolver)); + } +} diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java new file mode 100644 index 0000000000..a78eddb1b5 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerHandler.java @@ -0,0 +1,65 @@ +package spark.network.netty; + +import java.io.File; +import java.io.FileInputStream; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.DefaultFileRegion; + + +class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { + + PathResolver pResolver; + + public FileServerHandler(PathResolver pResolver){ + this.pResolver = pResolver; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, String blockId) { + String path = pResolver.getAbsolutePath(blockId); + // if getFilePath returns null, close the channel + if (path == null) { + //ctx.close(); + return; + } + File file = new File(path); + if (file.exists()) { + if (!file.isFile()) { + //logger.info("Not a file : " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + long length = file.length(); + if (length > Integer.MAX_VALUE || length <= 0) { + //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + int len = new Long(length).intValue(); + //logger.info("Sending block "+blockId+" filelen = "+len); + //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); + ctx.write((new FileHeader(len, blockId)).buffer()); + try { + ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + .getChannel(), 0, file.length())); + } catch (Exception e) { + //logger.warning("Exception when sending file : " + file.getAbsolutePath()); + e.printStackTrace(); + } + } else { + //logger.warning("File not found: " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + } + ctx.flush(); + } + + @Override + public void 
exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } +} diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java new file mode 100755 index 0000000000..302411672c --- /dev/null +++ b/core/src/main/java/spark/network/netty/PathResolver.java @@ -0,0 +1,12 @@ +package spark.network.netty;
+
+
+public interface PathResolver {
+ /**
+ * Get the absolute path of the file
+ *
+ * @param fileId the id of the file whose path is requested
+ * @return the absolute path of the file
+ */
+ public String getAbsolutePath(String fileId);
+}
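
For illustration, a minimal PathResolver implementation might look like the Scala sketch below (the class name SimplePathResolver and the baseDir parameter are assumptions for this example; the real resolver, defined later in this diff inside ShuffleSender, hashes block ids across a set of local directories). Note that FileServerHandler above treats a null path as unresolvable and returns without replying, which this sketch relies on:

  import java.io.File

  // Hypothetical resolver: look the file up directly under a single base directory.
  class SimplePathResolver(baseDir: File) extends PathResolver {
    override def getAbsolutePath(fileId: String): String = {
      val f = new File(baseDir, fileId)
      // Returning null signals "cannot resolve"; FileServerHandler then drops the request.
      if (f.isFile) f.getAbsolutePath else null
    }
  }
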
diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala new file mode 100644 index 0000000000..7b0e489a6c --- /dev/null +++ b/core/src/main/scala/spark/network/BufferMessage.scala @@ -0,0 +1,94 @@ +package spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import spark.storage.BlockManager + + +private[spark] +class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) + extends Message(Message.BUFFER_MESSAGE, id_) { + + val initialSize = currentSize() + var gotChunkForSendingOnce = false + + def size = initialSize + + def currentSize() = { + if (buffers == null || buffers.isEmpty) { + 0 + } else { + buffers.map(_.remaining).reduceLeft(_ + _) + } + } + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { + if (maxChunkSize <= 0) { + throw new Exception("Max chunk size is " + maxChunkSize) + } + + if (size == 0 && gotChunkForSendingOnce == false) { + val newChunk = new MessageChunk( + new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + gotChunkForSendingOnce = true + return Some(newChunk) + } + + while(!buffers.isEmpty) { + val buffer = buffers(0) + if (buffer.remaining == 0) { + BlockManager.dispose(buffer) + buffers -= buffer + } else { + val newBuffer = if (buffer.remaining <= maxChunkSize) { + buffer.duplicate() + } else { + buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] + } + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + gotChunkForSendingOnce = true + return Some(newChunk) + } + } + None + } + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { + // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer + if (buffers.size > 1) { + throw new Exception("Attempting to get chunk from message with multiple data buffers") + } + val buffer = buffers(0) + if (buffer.remaining > 0) { + if (buffer.remaining < chunkSize) { + throw new Exception("Not enough space in data buffer for receiving chunk") + } + val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + return Some(newChunk) + } + None + } + + def flip() { + buffers.foreach(_.flip) + } + + def hasAckId() = (ackId != 0) + + def isCompletelyReceived() = !buffers(0).hasRemaining + + override def toString = { + if (hasAckId) { + "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" + } else { + "BufferMessage(id = " + id + ", size = " + size + ")" + } + } +}
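
getChunkForSending above slices at most maxChunkSize bytes off the front of the current buffer, and getChunkForReceiving does the same with an exact chunkSize; both use the same ByteBuffer trick: take a view (duplicate() when the remainder fits, slice().limit(...) otherwise) and advance the parent buffer's position past it. A standalone sketch of that step (the method name nextChunk is assumed, not part of this change):

  import java.nio.ByteBuffer

  // Hand out up to maxChunkSize bytes of `buffer` as a view sharing its contents,
  // advancing `buffer`'s position past the bytes handed out, as BufferMessage does.
  def nextChunk(buffer: ByteBuffer, maxChunkSize: Int): ByteBuffer = {
    val view =
      if (buffer.remaining <= maxChunkSize) buffer.duplicate()
      else buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
    buffer.position(buffer.position + view.remaining)
    view
  }

Because the views share memory with the original buffer, splitting a message into chunks copies no data.
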
\ No newline at end of file diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 00a0433a44..6e28f677a3 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -13,12 +13,13 @@ import java.net._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging { + val socketRemoteConnectionManagerId: ConnectionManagerId) + extends Logging { + def this(channel_ : SocketChannel, selector_ : Selector) = { this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - )) + ConnectionManagerId.fromSocketAddress( + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) } channel.configureBlocking(false) @@ -32,17 +33,19 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() - + // Read channels typically do not register for write and write does not for read // Now, we do have write registering for read too (temporarily), but this is to detect // channel close NOT to actually read/consume data on it ! // How does this work if/when we move to SSL ? - + // What is the interest to register with selector for when we want this connection to be selected def registerInterest() - // What is the interest to register with selector for when we want this connection to be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be - // SelectionKey.OP_READ (until we fix it properly) + + // What is the interest to register with selector for when we want this connection to + // be de-selected + // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, + // it will be SelectionKey.OP_READ (until we fix it properly) def unregisterInterest() // On receiving a read event, should we change the interest for this channel or not ? @@ -64,12 +67,14 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, // Returns whether we have to register for further reads or not. def read(): Boolean = { - throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + throw new UnsupportedOperationException( + "Cannot read on connection of type " + this.getClass.toString) } // Returns whether we have to register for further writes or not. 
def write(): Boolean = { - throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + throw new UnsupportedOperationException( + "Cannot write on connection of type " + this.getClass.toString) } def close() { @@ -81,11 +86,17 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, callOnCloseCallback() } - def onClose(callback: Connection => Unit) {onCloseCallback = callback} + def onClose(callback: Connection => Unit) { + onCloseCallback = callback + } - def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback} + def onException(callback: (Connection, Exception) => Unit) { + onExceptionCallback = callback + } - def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback} + def onKeyInterestChange(callback: (Connection, Int) => Unit) { + onKeyInterestChangeCallback = callback + } def callOnExceptionCallback(e: Exception) { if (onExceptionCallback != null) { @@ -95,7 +106,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, " and OnExceptionCallback not registered", e) } } - + def callOnCloseCallback() { if (onCloseCallback != null) { onCloseCallback(this) @@ -132,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, print(" (" + position + ", " + length + ")") buffer.position(curPosition) } - } -private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) -extends Connection(SocketChannel.open, selector_, remoteId_) { +private[spark] +class SendingConnection(val address: InetSocketAddress, selector_ : Selector, + remoteId_ : ConnectionManagerId) + extends Connection(SocketChannel.open, selector_, remoteId_) { class Outbox(fair: Int = 0) { val messages = new Queue[Message]() - val defaultChunkSize = 65536 //32768 //16384 + val defaultChunkSize = 65536 //32768 //16384 var nextMessageToBeUsed = 0 def addMessage(message: Message) { - messages.synchronized{ + messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]") + logDebug("Added [" + message + "] to outbox for sending to " + + "[" + getRemoteConnectionManagerId() + "]") } } @@ -174,7 +186,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { message.started = true message.startTime = System.currentTimeMillis } - return chunk + return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ message.finishTime = System.currentTimeMillis @@ -185,7 +197,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } None } - + private def getChunkRR(): Option[MessageChunk] = { messages.synchronized { while (!messages.isEmpty) { @@ -197,12 +209,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.enqueue(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { - logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") + logDebug( + "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") message.started = true message.startTime = System.currentTimeMillis } - logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") - return chunk + logTrace( + "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") + return chunk } 
else { message.finishTime = System.currentTimeMillis logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + @@ -213,7 +227,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { None } } - + private val outbox = new Outbox(1) val currentBuffers = new ArrayBuffer[ByteBuffer]() @@ -228,11 +242,11 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // it does - so let us keep it for now. changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) } - + override def unregisterInterest() { changeConnectionKeyInterest(DEFAULT_INTEREST) } - + def send(message: Message) { outbox.synchronized { outbox.addMessage(message) @@ -262,12 +276,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // selection - though need not necessarily always complete successfully. val connected = channel.finishConnect if (!force && !connected) { - logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") + logInfo( + "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") return false } // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times (10 or so) + // This will happen only when finishConnect failed for some repeated number of times + // (10 or so) // Is highly unlikely unless there was an unclean close of socket, etc registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") @@ -283,13 +299,13 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } override def write(): Boolean = { - try{ - while(true) { + try { + while (true) { if (currentBuffers.size == 0) { outbox.synchronized { outbox.getChunk() match { case Some(chunk) => { - currentBuffers ++= chunk.buffers + currentBuffers ++= chunk.buffers } case None => { // changeConnectionKeyInterest(0) @@ -299,7 +315,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } } - + if (currentBuffers.size > 0) { val buffer = currentBuffers(0) val remainingBytes = buffer.remaining @@ -314,7 +330,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() @@ -336,7 +352,8 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { if (length == -1) { // EOF close() } else if (length > 0) { - logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) + logWarning( + "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) } } catch { case e: Exception => @@ -355,30 +372,32 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) -extends Connection(channel_, selector_) { - +private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) + extends Connection(channel_, selector_) { + class Inbox() { val messages = new HashMap[Int, BufferMessage]() - + def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - + def createNewMessage: BufferMessage = { val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis - logDebug("Starting 
to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") + logDebug( + "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) newMessage } - + val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") + logTrace( + "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") message.getChunkForReceiving(header.chunkSize) } - + def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) + messages.get(chunk.header.id) } def removeMessage(message: Message) { @@ -387,12 +406,14 @@ extends Connection(channel_, selector_) { } @volatile private var inferredRemoteManagerId: ConnectionManagerId = null + override def getRemoteConnectionManagerId(): ConnectionManagerId = { val currId = inferredRemoteManagerId if (currId != null) currId else super.getRemoteConnectionManagerId() } - // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver. + // The reciever's remote address is the local socket on remote side : which is NOT + // the connection manager id of the receiver. // We infer that from the messages we receive on the receiver socket. private def processConnectionManagerId(header: MessageChunkHeader) { val currId = inferredRemoteManagerId @@ -428,7 +449,8 @@ extends Connection(channel_, selector_) { } headerBuffer.flip if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") + throw new Exception( + "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") } val header = MessageChunkHeader.create(headerBuffer) headerBuffer.clear() @@ -451,9 +473,9 @@ extends Connection(channel_, selector_) { case _ => throw new Exception("Message of unknown type received") } } - + if (currentChunk == null) throw new Exception("No message chunk to receive data") - + val bytesRead = channel.read(currentChunk.buffer) if (bytesRead == 0) { // re-register for read event ... 
@@ -464,14 +486,15 @@ extends Connection(channel_, selector_) { } /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ - + if (currentChunk.buffer.remaining == 0) { /*println("Filled buffer at " + System.currentTimeMillis)*/ val bufferMessage = inbox.getMessageForChunk(currentChunk).get if (bufferMessage.isCompletelyReceived) { bufferMessage.flip bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) + logDebug("Finished receiving [" + bufferMessage + "] from " + + "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) if (onReceiveCallback != null) { onReceiveCallback(this, bufferMessage) } @@ -481,7 +504,7 @@ extends Connection(channel_, selector_) { } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() @@ -491,7 +514,7 @@ extends Connection(channel_, selector_) { // should not happen - to keep scala compiler happy return true } - + def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} override def changeInterestForRead(): Boolean = true @@ -505,7 +528,7 @@ extends Connection(channel_, selector_) { // it does - so let us keep it for now. changeConnectionKeyInterest(SelectionKey.OP_READ) } - + override def unregisterInterest() { changeConnectionKeyInterest(0) } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 0eb03630d0..624a094856 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -18,20 +18,7 @@ import akka.dispatch.{Await, Promise, ExecutionContext, Future} import akka.util.Duration import akka.util.duration._ -private[spark] case class ConnectionManagerId(host: String, port: Int) { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - def toSocketAddress() = new InetSocketAddress(host, port) -} - -private[spark] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) - } -} - private[spark] class ConnectionManager(port: Int) extends Logging { class MessageStatus( @@ -45,7 +32,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def markDone() { completionHandler(this) } } - + private val selector = SelectorProvider.provider.openSelector() private val handleMessageExecutor = new ThreadPoolExecutor( @@ -80,7 +67,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) + serverChannel.socket.setReceiveBufferSize(256 * 1024) serverChannel.socket.bind(new InetSocketAddress(port)) serverChannel.register(selector, SelectionKey.OP_ACCEPT) @@ -351,7 +338,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { case e: Exception => logError("Error in select loop", e) } } - + def acceptConnection(key: SelectionKey) { val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] @@ -463,7 +450,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = 
ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") + logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { @@ -483,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (bufferMessage.hasAckId) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId + case Some(status) => { + messageStatuses -= bufferMessage.ackId status } - case None => { + case None => { throw new Exception("Could not find reference for received ack message " + message.id) null } @@ -507,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logDebug("Not calling back as callback is null") None } - + if (ackMessage.isDefined) { if (!ackMessage.get.isInstanceOf[BufferMessage]) { logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) @@ -517,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - sendMessage(connectionManagerId, ackMessage.getOrElse { + sendMessage(connectionManagerId, ackMessage.getOrElse { Message.createBufferMessage(bufferMessage.id) }) } @@ -588,17 +575,17 @@ private[spark] object ConnectionManager { def main(args: Array[String]) { val manager = new ConnectionManager(9999) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None }) - + /*testSequentialSending(manager)*/ /*System.gc()*/ /*testParallelSending(manager)*/ /*System.gc()*/ - + /*testParallelDecreasingSending(manager)*/ /*System.gc()*/ @@ -610,9 +597,9 @@ private[spark] object ConnectionManager { println("--------------------------") println("Sequential Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 - + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -628,7 +615,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Parallel Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) @@ -643,12 +630,12 @@ private[spark] object ConnectionManager { if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis - + val mb = size * count / 1024.0 / 1024.0 val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) + println("Started at " + startTime + ", finished at " + finishTime) println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() @@ -658,7 +645,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Parallel Decreasing Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) 
buffers.foreach(_.flip) @@ -673,7 +660,7 @@ private[spark] object ConnectionManager { if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis - + val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") @@ -687,7 +674,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Continuous Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala new file mode 100644 index 0000000000..b554e84251 --- /dev/null +++ b/core/src/main/scala/spark/network/ConnectionManagerId.scala @@ -0,0 +1,21 @@ +package spark.network + +import java.net.InetSocketAddress + +import spark.Utils + + +private[spark] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + def toSocketAddress() = new InetSocketAddress(host, port) +} + + +private[spark] object ConnectionManagerId { + def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { + new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) + } +} diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 34fac9e776..d4f03610eb 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -1,56 +1,10 @@ package spark.network -import spark._ - -import scala.collection.mutable.ArrayBuffer - import java.nio.ByteBuffer -import java.net.InetAddress import java.net.InetSocketAddress -import storage.BlockManager - -private[spark] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). 
- flip.asInstanceOf[ByteBuffer] - } - - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" -} -private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } +import scala.collection.mutable.ArrayBuffer - override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" -} private[spark] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -59,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var finishTime = -1L def size: Int - + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - + def timeTaken(): String = (finishTime - startTime).toString + " ms" override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" } -private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) -extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size = initialSize - - def currentSize() = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - if (size == 0 && gotChunkForSendingOnce == false) { - val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId() = (ackId != 0) - - def isCompletelyReceived() = !buffers(0).hasRemaining - - override def toString = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - 
"BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} - -private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) - } -} private[spark] object Message { val BUFFER_MESSAGE = 1111111111L @@ -181,14 +31,16 @@ private[spark] object Message { def getNewId() = synchronized { lastId += 1 - if (lastId == 0) lastId += 1 + if (lastId == 0) { + lastId += 1 + } lastId } def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { if (dataBuffers == null) { return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } + } if (dataBuffers.exists(_ == null)) { throw new Exception("Attempting to create buffer message with null buffer") } @@ -197,7 +49,7 @@ private[spark] object Message { def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = createBufferMessage(dataBuffers, 0) - + def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { if (dataBuffer == null) { return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) @@ -205,15 +57,18 @@ private[spark] object Message { return createBufferMessage(Array(dataBuffer), ackId) } } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = + + def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId) + + def createBufferMessage(ackId: Int): BufferMessage = { + createBufferMessage(new Array[ByteBuffer](0), ackId) + } def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) + case BUFFER_MESSAGE => new BufferMessage(header.id, + ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } newMessage.senderAddress = header.address newMessage diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala new file mode 100644 index 0000000000..aaf9204d0e --- /dev/null +++ b/core/src/main/scala/spark/network/MessageChunk.scala @@ -0,0 +1,25 @@ +package spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + + +private[network] +class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { + + val size = if (buffer == null) 0 else buffer.remaining + + lazy val buffers = { + val ab = new ArrayBuffer[ByteBuffer]() + ab += header.buffer + if (buffer != null) { + ab += buffer + } + ab + } + + override def toString = { + "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" + } +} diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala new file mode 100644 index 0000000000..3693d509d6 --- /dev/null +++ 
b/core/src/main/scala/spark/network/MessageChunkHeader.scala @@ -0,0 +1,58 @@ +package spark.network + +import java.net.InetAddress +import java.net.InetSocketAddress +import java.nio.ByteBuffer + + +private[spark] class MessageChunkHeader( + val typ: Long, + val id: Int, + val totalSize: Int, + val chunkSize: Int, + val other: Int, + val address: InetSocketAddress) { + lazy val buffer = { + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. + // Refer to network.Connection + val ip = address.getAddress.getAddress() + val port = address.getPort() + ByteBuffer. + allocate(MessageChunkHeader.HEADER_SIZE). + putLong(typ). + putInt(id). + putInt(totalSize). + putInt(chunkSize). + putInt(other). + putInt(ip.size). + put(ip). + putInt(port). + position(MessageChunkHeader.HEADER_SIZE). + flip.asInstanceOf[ByteBuffer] + } + + override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + " and sizes " + totalSize + " / " + chunkSize + " bytes" +} + + +private[spark] object MessageChunkHeader { + val HEADER_SIZE = 40 + + def create(buffer: ByteBuffer): MessageChunkHeader = { + if (buffer.remaining != HEADER_SIZE) { + throw new IllegalArgumentException("Cannot convert buffer data to Message") + } + val typ = buffer.getLong() + val id = buffer.getInt() + val totalSize = buffer.getInt() + val chunkSize = buffer.getInt() + val other = buffer.getInt() + val ipSize = buffer.getInt() + val ipBytes = new Array[Byte](ipSize) + buffer.get(ipBytes) + val ip = InetAddress.getByAddress(ipBytes) + val port = buffer.getInt() + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + } +} diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala new file mode 100644 index 0000000000..aed4254234 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/FileHeader.scala @@ -0,0 +1,57 @@ +package spark.network.netty + +import io.netty.buffer._ + +import spark.Logging + +private[spark] class FileHeader ( + val fileLen: Int, + val blockId: String) extends Logging { + + lazy val buffer = { + val buf = Unpooled.buffer() + buf.capacity(FileHeader.HEADER_SIZE) + buf.writeInt(fileLen) + buf.writeInt(blockId.length) + blockId.foreach((x: Char) => buf.writeByte(x)) + //padding the rest of header + if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { + buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) + } else { + throw new Exception("too long header " + buf.readableBytes) + logInfo("too long header") + } + buf + } + +} + +private[spark] object FileHeader { + + val HEADER_SIZE = 40 + + def getFileLenOffset = 0 + def getFileLenSize = Integer.SIZE/8 + + def create(buf: ByteBuf): FileHeader = { + val length = buf.readInt + val idLength = buf.readInt + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buf.readByte().asInstanceOf[Char] + } + val blockId = idBuilder.toString() + new FileHeader(length, blockId) + } + + + def main (args:Array[String]){ + + val header = new FileHeader(25,"block_0"); + val buf = header.buffer; + val newheader = FileHeader.create(buf); + System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) + + } +} + diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala new file mode 100644 index 0000000000..a91f5a886d --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -0,0 
+1,84 @@ +package spark.network.netty + +import java.util.concurrent.Executors + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.util.CharsetUtil + +import spark.Logging +import spark.network.ConnectionManagerId + + +private[spark] class ShuffleCopier extends Logging { + + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + + val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) + val fc = new FileClient(handler) + fc.init() + fc.connect(cmId.host, cmId.port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } + + def getBlocks(cmId: ConnectionManagerId, + blocks: Seq[(String, Long)], + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + + for ((blockId, size) <- blocks) { + getBlock(cmId, blockId, resultCollectCallback) + } + } +} + + +private[spark] object ShuffleCopier extends Logging { + + private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + extends FileClientHandler with Logging { + + override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } + } + + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + } + + def runGetBlock(host:String, port:Int, file:String){ + val handler = new ShuffleClientHandler(echoResultCollectCallBack) + val fc = new FileClient(handler) + fc.init(); + fc.connect(host, port) + fc.sendRequest(file) + fc.waitForClose(); + fc.close() + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val file = args(2) + val threads = if (args.length > 3) args(3).toInt else 10 + + val copiers = Executors.newFixedThreadPool(80) + for (i <- Range(0, threads)) { + val runnable = new Runnable() { + def run() { + runGetBlock(host, port, file) + } + } + copiers.execute(runnable) + } + copiers.shutdown + } +} diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala new file mode 100644 index 0000000000..dc87fefc56 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala @@ -0,0 +1,56 @@ +package spark.network.netty + +import java.io.File + +import spark.Logging + + +private[spark] class ShuffleSender(val port: Int, val pResolver: PathResolver) extends Logging { + val server = new FileServer(pResolver) + + Runtime.getRuntime().addShutdownHook( + new Thread() { + override def run() { + server.stop() + } + } + ) + + def start() { + server.run(port) + } +} + + +private[spark] object ShuffleSender { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>") + System.exit(1) + } + + val port = args(0).toInt + val subDirsPerLocalDir = args(1).toInt + val localDirs = args.drop(2).map(new File(_)) + + val pResovler = new PathResolver { + override def getAbsolutePath(blockId: String): String = { + if (!blockId.startsWith("shuffle_")) { + throw new Exception("Block " + blockId + " is not a shuffle block") 
+ } + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = math.abs(blockId.hashCode) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) + val file = new File(subDir, blockId) + return file.getAbsolutePath + } + } + val sender = new ShuffleSender(port, pResovler) + + sender.start() + } +} diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala index 993aece1f7..0718156b1b 100644 --- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala +++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala @@ -1,10 +1,10 @@ package spark.storage private[spark] trait BlockFetchTracker { - def totalBlocks : Int - def numLocalBlocks: Int - def numRemoteBlocks: Int - def remoteFetchTime : Long - def fetchWaitTime: Long - def remoteBytesRead : Long + def totalBlocks : Int + def numLocalBlocks: Int + def numRemoteBlocks: Int + def remoteFetchTime : Long + def fetchWaitTime: Long + def remoteBytesRead : Long } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala new file mode 100644 index 0000000000..43f835237c --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -0,0 +1,360 @@ +package spark.storage + +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet +import scala.collection.mutable.Queue + +import io.netty.buffer.ByteBuf + +import spark.Logging +import spark.Utils +import spark.SparkException +import spark.network.BufferMessage +import spark.network.ConnectionManagerId +import spark.network.netty.ShuffleCopier +import spark.serializer.Serializer + + +/** + * A block fetcher iterator interface. There are two implementations: + * + * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. + * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. + * + * Eventually we would like the two to converge and use a single NIO-based communication layer, + * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), + * NIO would perform poorly and thus the need for the Netty OIO one. + */ + +private[storage] +trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] + with Logging with BlockFetchTracker { + def initialize() +} + + +private[storage] +object BlockFetcherIterator { + + // A request to fetch one or more blocks, complete with their sizes + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + // the block (since we want all deserializaton to happen in the calling thread); can also + // represent a fetch failure if size == -1. 
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } + + class BasicBlockFetcherIterator( + private val blockManager: BlockManager, + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer) + extends BlockFetcherIterator { + + import blockManager._ + + private var _remoteBytesRead = 0l + private var _remoteFetchTime = 0l + private var _fetchWaitTime = 0l + + if (blocksByAddress == null) { + throw new IllegalArgumentException("BlocksByAddress is null") + } + + protected var _totalBlocks = blocksByAddress.map(_._2.size).sum + logDebug("Getting " + _totalBlocks + " blocks") + protected var startTime = System.currentTimeMillis + protected val localBlockIds = new ArrayBuffer[String]() + protected val remoteBlockIds = new HashSet[String]() + + // A queue to hold our results. + protected val results = new LinkedBlockingQueue[FetchResult] + + // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + // the number of bytes in flight is limited to maxBytesInFlight + private val fetchRequests = new Queue[FetchRequest] + + // Current bytes in flight from our requests + private var bytesInFlight = 0L + + protected def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) + val blockMessageArray = new BlockMessageArray(req.blocks.map { + case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) + }) + bytesInFlight += req.size + val sizeMap = req.blocks.toMap // so we can look up the size of each blockID + val fetchStart = System.currentTimeMillis() + val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + future.onSuccess { + case Some(message) => { + val fetchDone = System.currentTimeMillis() + _remoteFetchTime += fetchDone - fetchStart + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new SparkException( + "Unexpected message " + blockMessage.getType + " received from " + cmId) + } + val blockId = blockMessage.getId + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) + _remoteBytesRead += req.size + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } + } + case None => { + logError("Could not get block(s) from " + cmId) + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } + } + } + + protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. 
+ val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + curBlocks += ((blockId, size)) + curRequestSize += size + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + remoteRequests + } + + protected def getLocalBlocks() { + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlockIds) { + getLocal(id) match { + case Some(iter) => { + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) + logDebug("Got local block " + id) + } + case None => { + throw new BlockException(id, "Could not get block " + id + " from local machine") + } + } + } + } + + override def initialize() { + // Split local and remote blocks. + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + + val numGets = remoteBlockIds.size - fetchRequests.size + logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + //an iterator that will read fetched blocks off the queue as they arrive. + @volatile protected var resultsGotten = 0 + + override def hasNext: Boolean = resultsGotten < _totalBlocks + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val startFetchWait = System.currentTimeMillis() + val result = results.take() + val stopFetchWait = System.currentTimeMillis() + _fetchWaitTime += (stopFetchWait - startFetchWait) + if (! result.failed) bytesInFlight -= result.size + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + + // Implementing BlockFetchTracker trait. 
+    // Implementing BlockFetchTracker trait.
+    override def totalBlocks: Int = _totalBlocks
+    override def numLocalBlocks: Int = localBlockIds.size
+    override def numRemoteBlocks: Int = remoteBlockIds.size
+    override def remoteFetchTime: Long = _remoteFetchTime
+    override def fetchWaitTime: Long = _fetchWaitTime
+    override def remoteBytesRead: Long = _remoteBytesRead
+  }
+  // End of BasicBlockFetcherIterator
+
+  class NettyBlockFetcherIterator(
+    blockManager: BlockManager,
+    blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+    serializer: Serializer)
+    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
+
+    import blockManager._
+
+    val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
+
+    private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
+      (for (i <- 0 until numCopiers) yield {
+        val copier = new Thread {
+          override def run() {
+            try {
+              while (!isInterrupted && !fetchRequestsSync.isEmpty) {
+                sendRequest(fetchRequestsSync.take())
+              }
+            } catch {
+              case x: InterruptedException => logInfo("Copier interrupted")
+              // case _ => throw new SparkException("Exception thrown in shuffle copier")
+            }
+          }
+        }
+        copier.start()
+        copier
+      }).toList
+    }
+
+    // Keep this around so the copier threads can be interrupted when necessary.
+    private def stopCopiers() {
+      for (copier <- copiers) {
+        copier.interrupt()
+      }
+    }
+
+    override protected def sendRequest(req: FetchRequest) {
+      def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+        val fetchResult = new FetchResult(blockId, blockSize,
+          () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
+        results.put(fetchResult)
+      }
+
+      logDebug("Sending request for %d blocks (%s) from %s".format(
+        req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host))
+      val cmId = new ConnectionManagerId(
+        req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
+      val copier = new ShuffleCopier
+      copier.getBlocks(cmId, req.blocks, putResult)
+      logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host)
+    }
+
+    override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+      // at most maxBytesInFlight in order to limit the amount of data in flight.
+      val originalTotalBlocks = _totalBlocks
+      val remoteRequests = new ArrayBuffer[FetchRequest]
+      for ((address, blockInfos) <- blocksByAddress) {
+        if (address == blockManagerId) {
+          localBlockIds ++= blockInfos.map(_._1)
+        } else {
+          remoteBlockIds ++= blockInfos.map(_._1)
+          // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+          // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+          // nodes, rather than blocking on reading output from one node.
+          val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+          logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+          val iterator = blockInfos.iterator
+          var curRequestSize = 0L
+          var curBlocks = new ArrayBuffer[(String, Long)]
+          while (iterator.hasNext) {
+            val (blockId, size) = iterator.next()
+            if (size > 0) {
+              curBlocks += ((blockId, size))
+              curRequestSize += size
+            } else if (size == 0) {
+              // Zero-byte blocks are skipped entirely, so adjust _totalBlocks to match.
+              _totalBlocks -= 1
+            } else {
+              throw new BlockException(blockId, "Negative block size " + size)
+            }
+            if (curRequestSize >= minRequestSize) {
+              // Add this FetchRequest
+              remoteRequests += new FetchRequest(address, curBlocks)
+              curRequestSize = 0
+              curBlocks = new ArrayBuffer[(String, Long)]
+            }
+          }
+          // Add in the final request
+          if (!curBlocks.isEmpty) {
+            remoteRequests += new FetchRequest(address, curBlocks)
+          }
+        }
+      }
+      logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
+        originalTotalBlocks + " blocks")
+      remoteRequests
+    }
+
+    private var copiers: List[_ <: Thread] = null
+
+    override def initialize() {
+      // Split local and remote blocks, adjusting _totalBlocks to count only non-zero-byte blocks.
+      val remoteRequests = splitLocalRemoteBlocks()
+      // Add the remote requests into our queue in a random order
+      for (request <- Utils.randomize(remoteRequests)) {
+        fetchRequestsSync.put(request)
+      }
+
+      copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
+      logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
+        Utils.getUsedTimeMs(startTime))
+
+      // Get local blocks
+      startTime = System.currentTimeMillis
+      getLocalBlocks()
+      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+    }
+
+    override def next(): (String, Option[Iterator[Any]]) = {
+      resultsGotten += 1
+      val result = results.take()
+      // If all the results have been retrieved, shut down the copier threads.
+      if (resultsGotten == _totalBlocks && copiers != null) {
+        stopCopiers()
+      }
+      (result.blockId, if (result.failed) None else Some(result.deserialize()))
+    }
+  }
+  // End of NettyBlockFetcherIterator
+}
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 72962a3ceb..40d608628e 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -2,7 +2,6 @@ package spark.storage
 
 import java.io.{InputStream, OutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
-import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
 
 import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet, Queue}
 import scala.collection.JavaConversions._
@@ -24,8 +23,7 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
 
 import sun.nio.ch.DirectBuffer
 
-private[spark]
-class BlockManager(
+private[spark] class BlockManager(
     executorId: String,
     actorSystem: ActorSystem,
     val master: BlockManagerMaster,
@@ -484,7 +482,16 @@ class BlockManager(
   def getMultiple(
     blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
     : BlockFetcherIterator = {
-    return new BlockFetcherIterator(this, blocksByAddress, serializer)
+
+    val iter =
+      if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+        new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+      } else {
+        new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+      }
+
+    iter.initialize()
+    iter
   }
 
   def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -928,8 +935,8 @@ class BlockManager(
     }
   }
 
-private[spark]
-object BlockManager extends Logging {
+
+private[spark] object BlockManager extends Logging {
 
   val ID_GENERATOR = new IdGenerator
 
@@ -962,7 +969,7 @@ object BlockManager extends Logging {
   def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = {
     // env == null and blockManagerMaster != null is used in tests
     assert (env != null || blockManagerMaster != null)
-    val locationBlockIds: Seq[Seq[BlockManagerId]] = 
+    val locationBlockIds: Seq[Seq[BlockManagerId]] =
       if (env != null) {
         val blockManager = env.blockManager
         blockManager.getLocationBlockIds(blockIds)
@@ -1000,176 +1007,3 @@ object BlockManager extends Logging {
 }
 
-class BlockFetcherIterator(
-    private val blockManager: BlockManager,
-    val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
-    serializer: Serializer
-) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
-
-  import blockManager._
-
-  private var _remoteBytesRead = 0l
-  private var _remoteFetchTime = 0l
-  private var _fetchWaitTime = 0l
-
-  if (blocksByAddress == null) {
-    throw new IllegalArgumentException("BlocksByAddress is null")
-  }
-  val totalBlocks = blocksByAddress.map(_._2.size).sum
-  logDebug("Getting " + totalBlocks + " blocks")
-  var startTime = System.currentTimeMillis
-  val localBlockIds = new ArrayBuffer[String]()
-  val remoteBlockIds = new HashSet[String]()
-
-  // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
-  // the block (since we want all deserializaton to happen in the calling thread); can also
-  // represent a fetch failure if size == -1.
-  class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
-    def failed: Boolean = size == -1
-  }
-
-  // A queue to hold our results.
-  val results = new LinkedBlockingQueue[FetchResult]
-
-  // A request to fetch one or more blocks, complete with their sizes
-  class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
-    val size = blocks.map(_._2).sum
-  }
-
-  // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
-  // the number of bytes in flight is limited to maxBytesInFlight
-  val fetchRequests = new Queue[FetchRequest]
-
-  // Current bytes in flight from our requests
-  var bytesInFlight = 0L
-
-  def sendRequest(req: FetchRequest) {
-    logDebug("Sending request for %d blocks (%s) from %s".format(
-      req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
-    val cmId = new ConnectionManagerId(req.address.host, req.address.port)
-    val blockMessageArray = new BlockMessageArray(req.blocks.map {
-      case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
-    })
-    bytesInFlight += req.size
-    val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
-    val fetchStart = System.currentTimeMillis()
-    val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
-    future.onSuccess {
-      case Some(message) => {
-        val fetchDone = System.currentTimeMillis()
-        _remoteFetchTime += fetchDone - fetchStart
-        val bufferMessage = message.asInstanceOf[BufferMessage]
-        val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
-        for (blockMessage <- blockMessageArray) {
-          if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
-            throw new SparkException(
-              "Unexpected message " + blockMessage.getType + " received from " + cmId)
-          }
-          val blockId = blockMessage.getId
-          results.put(new FetchResult(blockId, sizeMap(blockId),
-            () => dataDeserialize(blockId, blockMessage.getData, serializer)))
-          _remoteBytesRead += req.size
-          logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
-        }
-      }
-      case None => {
-        logError("Could not get block(s) from " + cmId)
-        for ((blockId, size) <- req.blocks) {
-          results.put(new FetchResult(blockId, -1, null))
-        }
-      }
-    }
-  }
-
-  // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
-  // at most maxBytesInFlight in order to limit the amount of data in flight.
-  val remoteRequests = new ArrayBuffer[FetchRequest]
-  for ((address, blockInfos) <- blocksByAddress) {
-    if (address == blockManagerId) {
-      localBlockIds ++= blockInfos.map(_._1)
-    } else {
-      remoteBlockIds ++= blockInfos.map(_._1)
-      // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
-      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
-      // nodes, rather than blocking on reading output from one node.
-      val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
-      logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
-      val iterator = blockInfos.iterator
-      var curRequestSize = 0L
-      var curBlocks = new ArrayBuffer[(String, Long)]
-      while (iterator.hasNext) {
-        val (blockId, size) = iterator.next()
-        curBlocks += ((blockId, size))
-        curRequestSize += size
-        if (curRequestSize >= minRequestSize) {
-          // Add this FetchRequest
-          remoteRequests += new FetchRequest(address, curBlocks)
-          curRequestSize = 0
-          curBlocks = new ArrayBuffer[(String, Long)]
-        }
-      }
-      // Add in the final request
-      if (!curBlocks.isEmpty) {
-        remoteRequests += new FetchRequest(address, curBlocks)
-      }
-    }
-  }
-  // Add the remote requests into our queue in a random order
-  fetchRequests ++= Utils.randomize(remoteRequests)
-
-  // Send out initial requests for blocks, up to our maxBytesInFlight
-  while (!fetchRequests.isEmpty &&
-    (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
-    sendRequest(fetchRequests.dequeue())
-  }
-
-  val numGets = remoteBlockIds.size - fetchRequests.size
-  logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
-  // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
-  // these all at once because they will just memory-map some files, so they won't consume
-  // any memory that might exceed our maxBytesInFlight
-  startTime = System.currentTimeMillis
-  for (id <- localBlockIds) {
-    getLocalFromDisk(id, serializer) match {
-      case Some(iter) => {
-        results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
-        logDebug("Got local block " + id)
-      }
-      case None => {
-        throw new BlockException(id, "Could not get block " + id + " from local machine")
-      }
-    }
-  }
-  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
-  //an iterator that will read fetched blocks off the queue as they arrive.
-  @volatile private var resultsGotten = 0
-
-  def hasNext: Boolean = resultsGotten < totalBlocks
-
-  def next(): (String, Option[Iterator[Any]]) = {
-    resultsGotten += 1
-    val startFetchWait = System.currentTimeMillis()
-    val result = results.take()
-    val stopFetchWait = System.currentTimeMillis()
-    _fetchWaitTime += (stopFetchWait - startFetchWait)
-    if (! result.failed) bytesInFlight -= result.size
-    while (!fetchRequests.isEmpty &&
-      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
-      sendRequest(fetchRequests.dequeue())
-    }
-    (result.blockId, if (result.failed) None else Some(result.deserialize()))
-  }
-
-
-  //methods to profile the block fetching
-  def numLocalBlocks = localBlockIds.size
-  def numRemoteBlocks = remoteBlockIds.size
-
-  def remoteFetchTime = _remoteFetchTime
-  def fetchWaitTime = _fetchWaitTime
-
-  def remoteBytesRead = _remoteBytesRead
-
-}
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 8154b8ca74..933eeaa216 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -14,13 +14,16 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
 import spark.Utils
 import spark.executor.ExecutorExitCode
 import spark.serializer.{Serializer, SerializationStream}
+import spark.Logging
+import spark.network.netty.ShuffleSender
+import spark.network.netty.PathResolver
 
 /**
  * Stores BlockManager blocks on disk.
 */
private class DiskStore(blockManager: BlockManager, rootDirs: String)
-  extends BlockStore(blockManager) {
+  extends BlockStore(blockManager) with Logging {
 
   class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
     extends BlockObjectWriter(blockId) {
@@ -79,19 +82,28 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
   val MAX_DIR_CREATION_ATTEMPTS: Int = 10
   val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
 
+  var shuffleSender: Thread = null
+  val thisInstance = this
   // Create one local directory for each path mentioned in spark.local.dir; then, inside this
   // directory, create multiple subdirectories that we will hash files into, in order to avoid
   // having really large inodes at the top level.
   val localDirs = createLocalDirs()
   val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+  val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
+
   addShutdownHook()
 
+  if (useNetty) {
+    startShuffleBlockSender()
+  }
+
   def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
     : BlockObjectWriter = {
     new DiskBlockObjectWriter(blockId, serializer, bufferSize)
   }
 
+
   override def getSize(blockId: String): Long = {
     getFile(blockId).length()
   }
@@ -262,10 +274,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
         localDirs.foreach { localDir =>
           if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
         }
+        if (useNetty && shuffleSender != null) {
+          // Note: Thread.stop is deprecated; it forcibly stops the daemon sender on shutdown.
+          shuffleSender.stop()
+        }
       } catch {
        case t: Throwable => logError("Exception while deleting local spark dirs", t)
      }
    }
  })
 }
+
+  private def startShuffleBlockSender() {
+    try {
+      val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt
+
+      val pResolver = new PathResolver {
+        override def getAbsolutePath(blockId: String): String = {
+          if (!blockId.startsWith("shuffle_")) {
+            return null
+          }
+          thisInstance.getFile(blockId).getAbsolutePath()
+        }
+      }
+      shuffleSender = new Thread {
+        override def run() = {
+          val sender = new ShuffleSender(port, pResolver)
+          logInfo("Created ShuffleSender binding to port: " + port)
+          sender.start
+        }
+      }
+      shuffleSender.setDaemon(true)
+      shuffleSender.start()
+
+    } catch {
+      case interrupted: InterruptedException =>
+        logInfo("Runner thread for ShuffleBlockSender interrupted")
+
+      case e: Exception => {
+        logError("Error running ShuffleBlockSender", e)
+        if (shuffleSender != null) {
+          shuffleSender.stop()
+          shuffleSender = null
+        }
+      }
+    }
+  }
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index 33c99471c6..06a94ed24c 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -18,10 +18,13 @@ import scala.collection.mutable.ArrayBuffer
 import SparkContext._
 import storage.{GetBlock, BlockManagerWorker, StorageLevel}
 
+
 class NotSerializableClass
 class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
 
+
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+  with LocalSparkContext {
 
   val clusterUrl = "local-cluster[2,1,512]"
 
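Editor's aside (illustrative, not part of this commit): the PathResolver above is the only seam between DiskStore and the Netty sender. A minimal standalone sketch of the same wiring, using a made-up local directory and the patch's default port; the call pattern mirrors DiskStore.startShuffleBlockSender:

import spark.network.netty.{PathResolver, ShuffleSender}

object SenderSketch {
  def main(args: Array[String]) {
    // Resolve only shuffle blocks, from a hypothetical fixed local dir;
    // anything else returns null, exactly like DiskStore's resolver.
    val resolver = new PathResolver {
      override def getAbsolutePath(blockId: String): String =
        if (blockId.startsWith("shuffle_")) "/tmp/spark-local/" + blockId else null
    }
    // Bind the sender and serve resolved files over Netty, as the patch does.
    val sender = new ShuffleSender(6653, resolver)
    sender.start
  }
}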
diff --git a/core/src/test/scala/spark/ShuffleNettySuite.scala b/core/src/test/scala/spark/ShuffleNettySuite.scala
new file mode 100644
index 0000000000..bfaffa953e
--- /dev/null
+++ b/core/src/test/scala/spark/ShuffleNettySuite.scala
@@ -0,0 +1,17 @@
+package spark
+
+import org.scalatest.BeforeAndAfterAll
+
+
+class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
+
+  // This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
+
+  override def beforeAll(configMap: Map[String, Any]) {
+    System.setProperty("spark.shuffle.use.netty", "true")
+  }
+
+  override def afterAll(configMap: Map[String, Any]) {
+    System.setProperty("spark.shuffle.use.netty", "false")
+  }
+}
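Editor's aside (illustrative, not part of this commit): end to end, the feature is switched on purely through system properties that the code above reads lazily, so a driver can opt in before creating its SparkContext. A minimal sketch under those assumptions; the property names and defaults are the ones used in this patch, while the master URL and the job itself are placeholders:

import spark.SparkContext
import spark.SparkContext._  // pair-RDD implicits for reduceByKey in this era of the API

object NettyShuffleDemo {
  def main(args: Array[String]) {
    // Same switches consulted by BlockManager.getMultiple and DiskStore above.
    System.setProperty("spark.shuffle.use.netty", "true")
    System.setProperty("spark.shuffle.sender.port", "6653")   // ShuffleSender port
    System.setProperty("spark.shuffle.copier.threads", "6")   // Netty copier threads

    val sc = new SparkContext("local-cluster[2,1,512]", "netty-shuffle-demo")
    // Any shuffle now fetches remote blocks through NettyBlockFetcherIterator
    // instead of the connection-manager path.
    val counts = sc.parallelize(1 to 1000).map(i => (i % 10, 1)).reduceByKey(_ + _)
    println(counts.collect().mkString(", "))
    sc.stop()
  }
}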