diff options
author | shane-huang <shengsheng.huang@intel.com> | 2013-02-20 11:51:13 +0800 |
---|---|---|
committer | shane-huang <shengsheng.huang@intel.com> | 2013-04-07 14:37:12 +0800 |
commit | df47b40b764e25cbd10ce49d7152e1d33f51a263 (patch) | |
tree | 46b3efea5434f02a5ee87c0b31c3bb59c74b773c | |
parent | dfe98ca798d84e6847330012f1332d9271156534 (diff) | |
download | spark-df47b40b764e25cbd10ce49d7152e1d33f51a263.tar.gz spark-df47b40b764e25cbd10ce49d7152e1d33f51a263.tar.bz2 spark-df47b40b764e25cbd10ce49d7152e1d33f51a263.zip |
Shuffle Performance fix: Use netty embeded OIO file server instead of ConnectionManager
Shuffle Performance Optimization: do not send 0-byte block requests to reduce network messages
change reference from io.Source to scala.io.Source to avoid looking into io.netty package
Signed-off-by: shane-huang <shengsheng.huang@intel.com>
14 files changed, 795 insertions, 56 deletions
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java new file mode 100644 index 0000000000..d0c5081dd2 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -0,0 +1,89 @@ +package spark.network.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioSocketChannel; + +import java.util.Arrays; + +public class FileClient { + + private FileClientHandler handler = null; + private Channel channel = null; + private Bootstrap bootstrap = null; + + public FileClient(FileClientHandler handler){ + this.handler = handler; + } + + public void init(){ + bootstrap = new Bootstrap(); + bootstrap.group(new OioEventLoopGroup()) + .channel(OioSocketChannel.class) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.TCP_NODELAY, true) + .handler(new FileClientChannelInitializer(handler)); + } + + public static final class ChannelCloseListener implements ChannelFutureListener { + private FileClient fc = null; + public ChannelCloseListener(FileClient fc){ + this.fc = fc; + } + @Override + public void operationComplete(ChannelFuture future) { + if (fc.bootstrap!=null){ + fc.bootstrap.shutdown(); + fc.bootstrap = null; + } + } + } + + public void connect(String host, int port){ + try { + + // Start the connection attempt. + channel = bootstrap.connect(host, port).sync().channel(); + // ChannelFuture cf = channel.closeFuture(); + //cf.addListener(new ChannelCloseListener(this)); + } catch (InterruptedException e) { + close(); + } + } + + public void waitForClose(){ + try { + channel.closeFuture().sync(); + } catch (InterruptedException e){ + e.printStackTrace(); + } + } + + public void sendRequest(String file){ + //assert(file == null); + //assert(channel == null); + channel.write(file+"\r\n"); + } + + public void close(){ + if(channel != null) { + channel.close(); + channel = null; + } + if ( bootstrap!=null) { + bootstrap.shutdown(); + bootstrap = null; + } + } + + +} + + diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java new file mode 100644 index 0000000000..50e5704619 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java @@ -0,0 +1,29 @@ +package spark.network.netty; + +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.string.StringEncoder; +import io.netty.util.CharsetUtil; + +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.logging.LogLevel; + +public class FileClientChannelInitializer extends + ChannelInitializer<SocketChannel> { + + private FileClientHandler fhandler; + + public FileClientChannelInitializer(FileClientHandler handler) { + fhandler = handler; + } + + @Override + public void initChannel(SocketChannel channel) { + // file no more than 2G + channel.pipeline() + .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("handler", fhandler); + } +} diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java new file mode 100644 index 0000000000..911c8b32b5 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -0,0 +1,38 @@ +package spark.network.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundByteHandlerAdapter; +import io.netty.util.CharsetUtil; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Logger; + +public abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { + + private FileHeader currentHeader = null; + + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + + @Override + public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { + // Use direct buffer if possible. + return ctx.alloc().ioBuffer(); + } + + @Override + public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + // get header + if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + } + // get file + if(in.readableBytes() >= currentHeader.fileLen()){ + handle(ctx,in,currentHeader); + currentHeader = null; + ctx.close(); + } + } + +} + diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java new file mode 100644 index 0000000000..729e45f0a1 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -0,0 +1,59 @@ +package spark.network.netty; + +import java.io.File; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.Channel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioServerSocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; + +/** + * Server that accept the path of a file an echo back its content. + */ +public class FileServer { + + private ServerBootstrap bootstrap = null; + private Channel channel = null; + private PathResolver pResolver; + + public FileServer(PathResolver pResolver){ + this.pResolver = pResolver; + } + + public void run(int port) { + // Configure the server. + bootstrap = new ServerBootstrap(); + try { + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + .channel(OioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .option(ChannelOption.SO_RCVBUF, 1500) + .childHandler(new FileServerChannelInitializer(pResolver)); + // Start the server. + channel = bootstrap.bind(port).sync().channel(); + channel.closeFuture().sync(); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } finally{ + bootstrap.shutdown(); + } + } + + public void stop(){ + if (channel!=null){ + channel.close(); + } + if (bootstrap != null){ + bootstrap.shutdown(); + bootstrap = null; + } + } +} + + diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java new file mode 100644 index 0000000000..9d0618ff1c --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java @@ -0,0 +1,33 @@ +package spark.network.netty; + +import java.io.File; +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.string.StringDecoder; +import io.netty.handler.codec.string.StringEncoder; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.util.CharsetUtil; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.logging.LogLevel; + +public class FileServerChannelInitializer extends + ChannelInitializer<SocketChannel> { + + PathResolver pResolver; + + public FileServerChannelInitializer(PathResolver pResolver) { + this.pResolver = pResolver; + } + + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline() + .addLast("framer", new DelimiterBasedFrameDecoder( + 8192, Delimiters.lineDelimiter())) + .addLast("strDecoder", new StringDecoder()) + .addLast("handler", new FileServerHandler(pResolver)); + + } +} diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java new file mode 100644 index 0000000000..e1083e87a2 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerHandler.java @@ -0,0 +1,68 @@ +package spark.network.netty; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.DefaultFileRegion; +import io.netty.handler.stream.ChunkedFile; +import java.io.File; +import java.io.FileInputStream; + +public class FileServerHandler extends + ChannelInboundMessageHandlerAdapter<String> { + + PathResolver pResolver; + + public FileServerHandler(PathResolver pResolver){ + this.pResolver = pResolver; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, String blockId) { + String path = pResolver.getAbsolutePath(blockId); + // if getFilePath returns null, close the channel + if (path == null) { + //ctx.close(); + return; + } + File file = new File(path); + if (file.exists()) { + if (!file.isFile()) { + //logger.info("Not a file : " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + long length = file.length(); + if (length > Integer.MAX_VALUE || length <= 0 ) { + //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + int len = new Long(length).intValue(); + //logger.info("Sending block "+blockId+" filelen = "+len); + //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); + ctx.write((new FileHeader(len, blockId)).buffer()); + try { + ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + .getChannel(), 0, file.length())); + } catch (Exception e) { + // TODO Auto-generated catch block + //logger.warning("Exception when sending file : " + //+ file.getAbsolutePath()); + e.printStackTrace(); + } + } else { + //logger.warning("File not found: " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + } + ctx.flush(); + } + + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } +} diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java new file mode 100755 index 0000000000..5d5eda006e --- /dev/null +++ b/core/src/main/java/spark/network/netty/PathResolver.java @@ -0,0 +1,12 @@ +package spark.network.netty;
+
+public interface PathResolver {
+ /**
+ * Get the absolute path of the file
+ *
+ * @param fileId
+ * @return the absolute path of file
+ */
+ public String getAbsolutePath(String fileId);
+
+}
diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala new file mode 100644 index 0000000000..aed4254234 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/FileHeader.scala @@ -0,0 +1,57 @@ +package spark.network.netty + +import io.netty.buffer._ + +import spark.Logging + +private[spark] class FileHeader ( + val fileLen: Int, + val blockId: String) extends Logging { + + lazy val buffer = { + val buf = Unpooled.buffer() + buf.capacity(FileHeader.HEADER_SIZE) + buf.writeInt(fileLen) + buf.writeInt(blockId.length) + blockId.foreach((x: Char) => buf.writeByte(x)) + //padding the rest of header + if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { + buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) + } else { + throw new Exception("too long header " + buf.readableBytes) + logInfo("too long header") + } + buf + } + +} + +private[spark] object FileHeader { + + val HEADER_SIZE = 40 + + def getFileLenOffset = 0 + def getFileLenSize = Integer.SIZE/8 + + def create(buf: ByteBuf): FileHeader = { + val length = buf.readInt + val idLength = buf.readInt + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buf.readByte().asInstanceOf[Char] + } + val blockId = idBuilder.toString() + new FileHeader(length, blockId) + } + + + def main (args:Array[String]){ + + val header = new FileHeader(25,"block_0"); + val buf = header.buffer; + val newheader = FileHeader.create(buf); + System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) + + } +} + diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala new file mode 100644 index 0000000000..d8d35bfeec --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -0,0 +1,88 @@ +package spark.network.netty + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelInboundByteHandlerAdapter +import io.netty.util.CharsetUtil + +import java.util.concurrent.atomic.AtomicInteger +import java.util.logging.Logger +import spark.Logging +import spark.network.ConnectionManagerId +import java.util.concurrent.Executors + +private[spark] class ShuffleCopier extends Logging { + + def getBlock(cmId: ConnectionManagerId, + blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + + val handler = new ShuffleClientHandler(resultCollectCallback) + val fc = new FileClient(handler) + fc.init() + fc.connect(cmId.host, cmId.port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } + + def getBlocks(cmId: ConnectionManagerId, + blocks: Seq[(String, Long)], + resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + + blocks.map { + case(blockId,size) => { + getBlock(cmId,blockId,resultCollectCallback) + } + } + } +} + +private[spark] class ShuffleClientHandler(val resultCollectCallBack: (String, Long, ByteBuf) => Unit ) extends FileClientHandler with Logging { + + def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } +} + +private[spark] object ShuffleCopier extends Logging { + + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) = { + logInfo("File: " + blockId + " content is : \" " + + content.toString(CharsetUtil.UTF_8) + "\"") + } + + def runGetBlock(host:String, port:Int, file:String){ + val handler = new ShuffleClientHandler(echoResultCollectCallBack) + val fc = new FileClient(handler) + fc.init(); + fc.connect(host, port) + fc.sendRequest(file) + fc.waitForClose(); + fc.close() + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val file = args(2) + val threads = if (args.length>3) args(3).toInt else 10 + + val copiers = Executors.newFixedThreadPool(80) + for (i <- Range(0,threads)){ + val runnable = new Runnable() { + def run() { + runGetBlock(host,port,file) + } + } + copiers.execute(runnable) + } + copiers.shutdown + } + +} diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala new file mode 100644 index 0000000000..c1986812e9 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala @@ -0,0 +1,50 @@ +package spark.network.netty + +import spark.Logging +import java.io.File + + +private[spark] class ShuffleSender(val port: Int, val pResolver:PathResolver) extends Logging { + val server = new FileServer(pResolver) + + Runtime.getRuntime().addShutdownHook( + new Thread() { + override def run() { + server.stop() + } + } + ) + + def start() { + server.run(port) + } +} + +private[spark] object ShuffleSender { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>") + System.exit(1) + } + val port = args(0).toInt + val subDirsPerLocalDir = args(1).toInt + val localDirs = args.drop(2) map {new File(_)} + val pResovler = new PathResolver { + def getAbsolutePath(blockId:String):String = { + if (!blockId.startsWith("shuffle_")) { + throw new Exception("Block " + blockId + " is not a shuffle block") + } + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = math.abs(blockId.hashCode) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) + val file = new File(subDir, blockId) + return file.getAbsolutePath + } + } + val sender = new ShuffleSender(port, pResovler) + + sender.start() + } +} diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 210061e972..b8b68d4283 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -23,6 +23,8 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam import sun.nio.ch.DirectBuffer +import spark.network.netty.ShuffleCopier +import io.netty.buffer.ByteBuf private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) @@ -468,6 +470,21 @@ class BlockManager( } /** + * A request to fetch one or more blocks, complete with their sizes + */ + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + /** + * A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + * the block (since we want all deserializaton to happen in the calling thread); can also + * represent a fetch failure if size == -1. + */ + class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } + /** * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined * fashion as they're received. Expects a size in bytes to be provided for each block fetched, @@ -475,7 +492,12 @@ class BlockManager( */ def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]) : BlockFetcherIterator = { - return new BlockFetcherIterator(this, blocksByAddress) + + if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){ + return new NettyBlockFetcherIterator(this, blocksByAddress) + } else { + return new BlockFetcherIterator(this, blocksByAddress) + } } def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) @@ -908,7 +930,7 @@ class BlockFetcherIterator( if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } - val totalBlocks = blocksByAddress.map(_._2.size).sum + var totalBlocks = blocksByAddress.map(_._2.size).sum logDebug("Getting " + totalBlocks + " blocks") var startTime = System.currentTimeMillis val localBlockIds = new ArrayBuffer[String]() @@ -974,68 +996,83 @@ class BlockFetcherIterator( } } - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - } else { - remoteBlockIds ++= blockInfos.map(_._1) - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - curBlocks += ((blockId, size)) - curRequestSize += size - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest + def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + curBlocks += ((blockId, size)) + curRequestSize += size + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] } } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } } + remoteRequests } - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) + def getLocalBlocks(){ + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlockIds) { + getLocal(id) match { + case Some(iter) => { + results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight + logDebug("Got local block " + id) + } + case None => { + throw new BlockException(id, "Could not get block " + id + " from local machine") + } + } + } } - val numGets = remoteBlockIds.size - fetchRequests.size - logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) - - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - startTime = System.currentTimeMillis - for (id <- localBlockIds) { - getLocal(id) match { - case Some(iter) => { - results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") - } + def initialize(){ + // Split local and remote blocks. + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) } + + val numGets = remoteBlockIds.size - fetchRequests.size + logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + initialize() //an iterator that will read fetched blocks off the queue as they arrive. var resultsGotten = 0 @@ -1066,3 +1103,132 @@ class BlockFetcherIterator( def remoteBytesRead = _remoteBytesRead } + +class NettyBlockFetcherIterator( + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] +) extends BlockFetcherIterator(blockManager,blocksByAddress) { + + import blockManager._ + + val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + + def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer, + results : LinkedBlockingQueue[FetchResult]){ + results.put(new FetchResult( + blockId, blockSize, () => dataDeserialize(blockId, blockData) )) + } + + def startCopiers (numCopiers: Int): List [ _ <: Thread]= { + (for ( i <- Range(0,numCopiers) ) yield { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { + sendRequest(fetchRequestsSync.take()) + } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + case _ => throw new SparkException("Exception Throw in Shuffle Copier") + } + } + } + copier.start + copier + }).toList + } + + //keep this to interrupt the threads when necessary + def stopCopiers(copiers : List[_ <: Thread]) { + for (copier <- copiers) { + copier.interrupt() + } + } + + override def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) + val cmId = new ConnectionManagerId(req.address.ip, System.getProperty("spark.shuffle.sender.port", "6653").toInt) + val cpier = new ShuffleCopier + cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results)) + logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.ip ) + } + + override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val originalTotalBlocks = totalBlocks; + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + if (size > 0) { + curBlocks += ((blockId, size)) + curRequestSize += size + } else if (size == 0){ + //here we changes the totalBlocks + totalBlocks -= 1 + } else { + throw new SparkException("Negative block size "+blockId) + } + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks") + remoteRequests + } + + var copiers : List[_ <: Thread] = null + + override def initialize(){ + // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + for (request <- Utils.randomize(remoteRequests)) { + fetchRequestsSync.put(request) + } + + copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) + logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val result = results.take() + // if all the results has been retrieved + // shutdown the copiers + if (resultsGotten == totalBlocks) { + if( copiers != null ) + stopCopiers(copiers) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + } + diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index ddbf8821ad..d702bb23e0 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -13,24 +13,35 @@ import scala.collection.mutable.ArrayBuffer import spark.executor.ExecutorExitCode import spark.Utils +import spark.Logging +import spark.network.netty.ShuffleSender +import spark.network.netty.PathResolver /** * Stores BlockManager blocks on disk. */ private class DiskStore(blockManager: BlockManager, rootDirs: String) - extends BlockStore(blockManager) { + extends BlockStore(blockManager) with Logging { val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + var shuffleSender : Thread = null + val thisInstance = this // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid // having really large inodes at the top level. val localDirs = createLocalDirs() val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean + addShutdownHook() + if(useNetty){ + startShuffleBlockSender() + } + override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -180,10 +191,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) logDebug("Shutdown hook called") try { localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + if (useNetty && shuffleSender != null) + shuffleSender.stop } catch { case t: Throwable => logError("Exception while deleting local spark dirs", t) } } }) } + + private def startShuffleBlockSender (){ + try { + val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt + + val pResolver = new PathResolver { + def getAbsolutePath(blockId:String):String = { + if (!blockId.startsWith("shuffle_")) { + return null + } + thisInstance.getFile(blockId).getAbsolutePath() + } + } + shuffleSender = new Thread { + override def run() = { + val sender = new ShuffleSender(port,pResolver) + logInfo("created ShuffleSender binding to port : "+ port) + sender.start + } + } + shuffleSender.setDaemon(true) + shuffleSender.start + + } catch { + case interrupted: InterruptedException => + logInfo("Runner thread for ShuffleBlockSender interrupted") + + case e: Exception => { + logError("Error running ShuffleBlockSender ", e) + if (shuffleSender != null) { + shuffleSender.stop + shuffleSender = null + } + } + } + } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5f378b2398..e3645653ee 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -141,7 +141,8 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", - "org.apache.mesos" % "mesos" % "0.9.0-incubating" + "org.apache.mesos" % "mesos" % "0.9.0-incubating", + "io.netty" % "netty-all" % "4.0.0.Beta2" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } ) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index d8b987ec86..bd0b0e74c1 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -5,7 +5,7 @@ import spark.util.{RateLimitedOutputStream, IntParam} import java.net.ServerSocket import spark.{Logging, KryoSerializer} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import io.Source +import scala.io.Source import java.io.IOException /** |