-rw-r--r--  core/src/main/java/spark/network/netty/FileClient.java                    |  89
-rw-r--r--  core/src/main/java/spark/network/netty/FileClientChannelInitializer.java  |  29
-rw-r--r--  core/src/main/java/spark/network/netty/FileClientHandler.java             |  38
-rw-r--r--  core/src/main/java/spark/network/netty/FileServer.java                    |  58
-rw-r--r--  core/src/main/java/spark/network/netty/FileServerChannelInitializer.java  |  33
-rw-r--r--  core/src/main/java/spark/network/netty/FileServerHandler.java             |  68
-rwxr-xr-x  core/src/main/java/spark/network/netty/PathResolver.java                  |  12
-rw-r--r--  core/src/main/scala/spark/network/netty/FileHeader.scala                  |  57
-rw-r--r--  core/src/main/scala/spark/network/netty/ShuffleCopier.scala               |  88
-rw-r--r--  core/src/main/scala/spark/network/netty/ShuffleSender.scala               |  50
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala                      | 302
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala                         |  52
-rw-r--r--  project/SparkBuild.scala                                                   |   3
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/RawTextSender.scala          |   2
14 files changed, 811 insertions, 70 deletions
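
The patch wires an optional Netty-based transfer path into the shuffle: a small Java client/server pair under spark/network/netty, Scala-side glue (FileHeader, ShuffleCopier, ShuffleSender), and a refactored BlockFetcherIterator that can fetch over either the existing ConnectionManager or the new path. A minimal sketch of how the feature is switched on, based only on the system properties read in this diff (the values shown are the defaults from the code):

    // Opt in to the Netty shuffle path; read in BlockManager.getMultiple and DiskStore.
    System.setProperty("spark.shuffle.use.netty", "true")
    // Port the ShuffleSender binds to (DiskStore / NettyBlockFetcherIterator default).
    System.setProperty("spark.shuffle.sender.port", "6653")
    // Number of copier threads per fetch (NettyBlockFetcherIterator default).
    System.setProperty("spark.shuffle.copier.threads", "6")
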
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java
new file mode 100644
index 0000000000..d0c5081dd2
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClient.java
@@ -0,0 +1,89 @@
+package spark.network.netty;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioSocketChannel;
+
+public class FileClient {
+
+ private FileClientHandler handler = null;
+ private Channel channel = null;
+ private Bootstrap bootstrap = null;
+
+  public FileClient(FileClientHandler handler) {
+ this.handler = handler;
+ }
+
+  public void init() {
+ bootstrap = new Bootstrap();
+ bootstrap.group(new OioEventLoopGroup())
+ .channel(OioSocketChannel.class)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.TCP_NODELAY, true)
+ .handler(new FileClientChannelInitializer(handler));
+ }
+
+ public static final class ChannelCloseListener implements ChannelFutureListener {
+ private FileClient fc = null;
+    public ChannelCloseListener(FileClient fc) {
+ this.fc = fc;
+ }
+ @Override
+ public void operationComplete(ChannelFuture future) {
+      if (fc.bootstrap != null) {
+ fc.bootstrap.shutdown();
+ fc.bootstrap = null;
+ }
+ }
+ }
+
+  public void connect(String host, int port) {
+    try {
+      // Start the connection attempt and block until the channel is ready.
+      channel = bootstrap.connect(host, port).sync().channel();
+    } catch (InterruptedException e) {
+      close();
+    }
+  }
+
+  public void waitForClose() {
+    try {
+      channel.closeFuture().sync();
+    } catch (InterruptedException e) {
+      e.printStackTrace();
+    }
+  }
+
+  public void sendRequest(String file) {
+    // The request is the block/file id terminated by "\r\n"; the server side
+    // splits incoming requests on this line delimiter.
+    channel.write(file + "\r\n");
+  }
+
+  public void close() {
+    if (channel != null) {
+      channel.close();
+      channel = null;
+    }
+    if (bootstrap != null) {
+      bootstrap.shutdown();
+      bootstrap = null;
+    }
+  }
+
+
+}
+
+
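
The client lifecycle is init, connect, sendRequest, waitForClose, close. A short Scala sketch, mirroring how ShuffleCopier.getBlock drives FileClient later in this patch (the handler construction and host/port are placeholders):

    val fc = new FileClient(handler)   // handler: some FileClientHandler subclass
    fc.init()                          // build the Bootstrap and pipeline
    fc.connect("localhost", 6653)      // hypothetical host/port
    fc.sendRequest("shuffle_0_0_0")    // block id, terminated with "\r\n" on the wire
    fc.waitForClose()                  // the handler closes the channel once the file arrives
    fc.close()
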
diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
new file mode 100644
index 0000000000..50e5704619
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
@@ -0,0 +1,29 @@
+package spark.network.netty;
+
+import io.netty.buffer.BufType;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.string.StringEncoder;
+
+public class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+ private FileClientHandler fhandler;
+
+ public FileClientChannelInitializer(FileClientHandler handler) {
+ fhandler = handler;
+ }
+
+ @Override
+ public void initChannel(SocketChannel channel) {
+    // Note: files are limited to 2GB since FileHeader stores the length as an int.
+ channel.pipeline()
+ .addLast("encoder", new StringEncoder(BufType.BYTE))
+ .addLast("handler", fhandler);
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java
new file mode 100644
index 0000000000..911c8b32b5
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientHandler.java
@@ -0,0 +1,38 @@
+package spark.network.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundByteHandlerAdapter;
+
+public abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
+
+ private FileHeader currentHeader = null;
+
+ public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
+
+ @Override
+ public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
+ // Use direct buffer if possible.
+ return ctx.alloc().ioBuffer();
+ }
+
+  @Override
+  public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
+    // Wait until the full header has arrived, then parse it.
+    if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
+      currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
+    }
+    // Once the whole file body is buffered, hand it to the subclass and close.
+    // The null check guards against a partial header read above.
+    if (currentHeader != null && in.readableBytes() >= currentHeader.fileLen()) {
+      handle(ctx, in, currentHeader);
+      currentHeader = null;
+      ctx.close();
+    }
+  }
+
+}
+
diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java
new file mode 100644
index 0000000000..38af305096
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServer.java
@@ -0,0 +1,58 @@
+package spark.network.netty;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioServerSocketChannel;
+
+/**
+ * Server that accepts the path of a file and echoes back its contents.
+ */
+public class FileServer {
+
+ private ServerBootstrap bootstrap = null;
+ private Channel channel = null;
+ private PathResolver pResolver;
+
+  public FileServer(PathResolver pResolver) {
+ this.pResolver = pResolver;
+ }
+
+ public void run(int port) {
+ // Configure the server.
+ bootstrap = new ServerBootstrap();
+ try {
+ bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
+ .channel(OioServerSocketChannel.class)
+ .option(ChannelOption.SO_BACKLOG, 100)
+ .option(ChannelOption.SO_RCVBUF, 1500)
+ .childHandler(new FileServerChannelInitializer(pResolver));
+ // Start the server.
+ channel = bootstrap.bind(port).sync().channel();
+ channel.closeFuture().sync();
+    } catch (InterruptedException e) {
+      e.printStackTrace();
+    } finally {
+      bootstrap.shutdown();
+    }
+ }
+
+  public void stop() {
+    if (channel != null) {
+      channel.close();
+    }
+    if (bootstrap != null) {
+      bootstrap.shutdown();
+    }
+  }
+}
+
+
diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
new file mode 100644
index 0000000000..9d0618ff1c
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
@@ -0,0 +1,33 @@
+package spark.network.netty;
+
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.DelimiterBasedFrameDecoder;
+import io.netty.handler.codec.Delimiters;
+import io.netty.handler.codec.string.StringDecoder;
+
+public class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+ PathResolver pResolver;
+
+ public FileServerChannelInitializer(PathResolver pResolver) {
+ this.pResolver = pResolver;
+ }
+
+  @Override
+  public void initChannel(SocketChannel channel) {
+    channel.pipeline()
+      .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
+      .addLast("strDecoder", new StringDecoder())
+      .addLast("handler", new FileServerHandler(pResolver));
+  }
+}
diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java
new file mode 100644
index 0000000000..e1083e87a2
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerHandler.java
@@ -0,0 +1,68 @@
+package spark.network.netty;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundMessageHandlerAdapter;
+import io.netty.channel.DefaultFileRegion;
+import java.io.File;
+import java.io.FileInputStream;
+
+public class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
+
+ PathResolver pResolver;
+
+ public FileServerHandler(PathResolver pResolver){
+ this.pResolver = pResolver;
+ }
+
+  @Override
+  public void messageReceived(ChannelHandlerContext ctx, String blockId) {
+    String path = pResolver.getAbsolutePath(blockId);
+    // A null path means the block id could not be resolved; ignore the request.
+    if (path == null) {
+      return;
+    }
+    File file = new File(path);
+    if (file.exists()) {
+      if (!file.isFile()) {
+        // Not a regular file: reply with a zero-length header to signal failure.
+        ctx.write(new FileHeader(0, blockId).buffer());
+        ctx.flush();
+        return;
+      }
+      long length = file.length();
+      if (length > Integer.MAX_VALUE || length <= 0) {
+        // Empty or too large (the header stores the length as an int).
+        ctx.write(new FileHeader(0, blockId).buffer());
+        ctx.flush();
+        return;
+      }
+      int len = (int) length;
+      ctx.write(new FileHeader(len, blockId).buffer());
+      try {
+        // Zero-copy transfer of the file body.
+        ctx.sendFile(new DefaultFileRegion(
+          new FileInputStream(file).getChannel(), 0, file.length()));
+      } catch (Exception e) {
+        e.printStackTrace();
+      }
+    } else {
+      // File not found: again signal with a zero-length header.
+      ctx.write(new FileHeader(0, blockId).buffer());
+    }
+    ctx.flush();
+  }
+
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ cause.printStackTrace();
+ ctx.close();
+ }
+}
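
To summarize the protocol this handler implements: the request is one line-delimited block id; the reply is always a FileHeader, followed by the raw file bytes (sent zero-copy via DefaultFileRegion) only when the block resolves to a regular file of positive size under 2GB. A zero-length header is the error signal, so a client needs just one check. A hedged Scala sketch (handleError and readBody are hypothetical helpers, not part of the patch):

    def onReply(header: FileHeader, in: ByteBuf) {
      if (header.fileLen == 0) handleError(header.blockId)  // missing, empty, or too large
      else readBody(in, header.fileLen)                     // exactly fileLen bytes follow
    }
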
diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java
new file mode 100755
index 0000000000..5d5eda006e
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/PathResolver.java
@@ -0,0 +1,12 @@
+package spark.network.netty;
+
+public interface PathResolver {
+  /**
+   * Get the absolute path of the file.
+   *
+   * @param fileId the block/file identifier
+   * @return the absolute path of the file, or null if it cannot be resolved
+   */
+  String getAbsolutePath(String fileId);
+}
diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala
new file mode 100644
index 0000000000..aed4254234
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/FileHeader.scala
@@ -0,0 +1,57 @@
+package spark.network.netty
+
+import io.netty.buffer._
+
+import spark.Logging
+
+private[spark] class FileHeader (
+ val fileLen: Int,
+ val blockId: String) extends Logging {
+
+  lazy val buffer = {
+    val buf = Unpooled.buffer()
+    buf.capacity(FileHeader.HEADER_SIZE)
+    buf.writeInt(fileLen)
+    buf.writeInt(blockId.length)
+    blockId.foreach((x: Char) => buf.writeByte(x))
+    // Pad the rest of the header with zeros; a block id that pushes the header
+    // past HEADER_SIZE cannot be encoded.
+    val padding = FileHeader.HEADER_SIZE - buf.readableBytes
+    if (padding >= 0) {
+      buf.writeZero(padding)
+    } else {
+      throw new Exception("too long header " + buf.readableBytes)
+    }
+    buf
+  }
+
+}
+
+private[spark] object FileHeader {
+
+ val HEADER_SIZE = 40
+
+ def getFileLenOffset = 0
+ def getFileLenSize = Integer.SIZE/8
+
+ def create(buf: ByteBuf): FileHeader = {
+ val length = buf.readInt
+ val idLength = buf.readInt
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buf.readByte().asInstanceOf[Char]
+ }
+ val blockId = idBuilder.toString()
+ new FileHeader(length, blockId)
+ }
+
+
+  def main(args: Array[String]) {
+    val header = new FileHeader(25, "block_0")
+    val buf = header.buffer
+    val newHeader = FileHeader.create(buf)
+    println("id=" + newHeader.blockId + ", size=" + newHeader.fileLen)
+  }
+}
+
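For reference, the wire format implied by buffer/create above is a fixed 40-byte record: a 4-byte fileLen, a 4-byte block-id length, the block-id characters one byte each, then zero padding. One consequence worth noting, assuming HEADER_SIZE stays at 40:

    // Room left for the block id after the two Ints:
    val maxBlockIdLen = FileHeader.HEADER_SIZE - 2 * (Integer.SIZE / 8)  // = 32 chars
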
diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
new file mode 100644
index 0000000000..d8d35bfeec
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
@@ -0,0 +1,88 @@
+package spark.network.netty
+
+import java.util.concurrent.Executors
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.ChannelHandlerContext
+import io.netty.util.CharsetUtil
+
+import spark.Logging
+import spark.network.ConnectionManagerId
+
+private[spark] class ShuffleCopier extends Logging {
+
+ def getBlock(cmId: ConnectionManagerId,
+ blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) = {
+
+ val handler = new ShuffleClientHandler(resultCollectCallback)
+ val fc = new FileClient(handler)
+ fc.init()
+ fc.connect(cmId.host, cmId.port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ }
+
+ def getBlocks(cmId: ConnectionManagerId,
+ blocks: Seq[(String, Long)],
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) = {
+
+    for ((blockId, size) <- blocks) {
+      getBlock(cmId, blockId, resultCollectCallback)
+    }
+ }
+}
+
+private[spark] class ShuffleClientHandler(val resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+  extends FileClientHandler with Logging {
+
+ def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
+ }
+}
+
+private[spark] object ShuffleCopier extends Logging {
+
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) = {
+ logInfo("File: " + blockId + " content is : \" "
+ + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
+
+  def runGetBlock(host: String, port: Int, file: String) {
+    val handler = new ShuffleClientHandler(echoResultCollectCallBack)
+    val fc = new FileClient(handler)
+    fc.init()
+    fc.connect(host, port)
+    fc.sendRequest(file)
+    fc.waitForClose()
+    fc.close()
+  }
+
+  def main(args: Array[String]) {
+    if (args.length < 3) {
+      System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> [<threads>]")
+      System.exit(1)
+    }
+    val host = args(0)
+    val port = args(1).toInt
+    val file = args(2)
+    val threads = if (args.length > 3) args(3).toInt else 10
+
+    // Size the pool to the requested number of copier threads.
+    val copiers = Executors.newFixedThreadPool(threads)
+    for (i <- 0 until threads) {
+      val runnable = new Runnable() {
+        def run() {
+          runGetBlock(host, port, file)
+        }
+      }
+      copiers.execute(runnable)
+    }
+    copiers.shutdown()
+  }
+
+}
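A usage sketch for the callback-based fetch API above (host, port, and block id are placeholders; the size element of each tuple is unused by getBlock itself):

    val copier = new ShuffleCopier
    val cmId = new ConnectionManagerId("shuffle-host", 6653)  // hypothetical sender address
    copier.getBlocks(cmId, Seq(("shuffle_0_0_0", 0L)),
      (blockId, size, content) => println(blockId + ": " + content.readableBytes + " bytes"))
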
diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
new file mode 100644
index 0000000000..c1986812e9
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
@@ -0,0 +1,50 @@
+package spark.network.netty
+
+import spark.Logging
+import java.io.File
+
+
+private[spark] class ShuffleSender(val port: Int, val pResolver: PathResolver) extends Logging {
+ val server = new FileServer(pResolver)
+
+ Runtime.getRuntime().addShutdownHook(
+ new Thread() {
+ override def run() {
+ server.stop()
+ }
+ }
+ )
+
+ def start() {
+ server.run(port)
+ }
+}
+
+private[spark] object ShuffleSender {
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
+ System.exit(1)
+ }
+ val port = args(0).toInt
+ val subDirsPerLocalDir = args(1).toInt
+ val localDirs = args.drop(2) map {new File(_)}
+    val pResolver = new PathResolver {
+      def getAbsolutePath(blockId: String): String = {
+        if (!blockId.startsWith("shuffle_")) {
+          throw new Exception("Block " + blockId + " is not a shuffle block")
+        }
+        // Figure out which local directory it hashes to, and which subdirectory in that
+        val hash = math.abs(blockId.hashCode)
+        val dirId = hash % localDirs.length
+        val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+        val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
+        val file = new File(subDir, blockId)
+        file.getAbsolutePath
+      }
+    }
+    val sender = new ShuffleSender(port, pResolver)
+
+ sender.start()
+ }
+}
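Standalone use is shown by main above; embedding the sender only requires a PathResolver. A minimal sketch with a flat directory layout (the root below is hypothetical; real layouts use the dirId/subDirId hashing shown above):

    val resolver = new PathResolver {
      def getAbsolutePath(blockId: String): String =
        new java.io.File("/tmp/shuffle-blocks", blockId).getAbsolutePath
    }
    new ShuffleSender(6653, resolver).start()  // blocks serving requests until stopped
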
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 09572b19db..433e939656 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -23,6 +23,8 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
import sun.nio.ch.DirectBuffer
+import spark.network.netty.ShuffleCopier
+import io.netty.buffer.ByteBuf
private[spark]
class BlockManager(
@@ -495,7 +497,11 @@ class BlockManager(
def getMultiple(
blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
- return new BlockFetcherIterator(this, blocksByAddress, serializer)
+    if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+      return BlockFetcherIterator("netty", this, blocksByAddress, serializer)
+    } else {
+      return BlockFetcherIterator("", this, blocksByAddress, serializer)
+    }
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -971,11 +977,30 @@ object BlockManager extends Logging {
}
}
-class BlockFetcherIterator(
+
+trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+  with Logging with BlockFetchTracker {
+  def initialize()
+}
+
+object BlockFetcherIterator {
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+ // the block (since we want all deserializaton to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
serializer: Serializer
-) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
+) extends BlockFetcherIterator {
import blockManager._
@@ -986,27 +1011,15 @@ class BlockFetcherIterator(
if (blocksByAddress == null) {
throw new IllegalArgumentException("BlocksByAddress is null")
}
- val totalBlocks = blocksByAddress.map(_._2.size).sum
+ var totalBlocks = blocksByAddress.map(_._2.size).sum
logDebug("Getting " + totalBlocks + " blocks")
var startTime = System.currentTimeMillis
val localBlockIds = new ArrayBuffer[String]()
val remoteBlockIds = new HashSet[String]()
- // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
- // the block (since we want all deserializaton to happen in the calling thread); can also
- // represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
- }
-
// A queue to hold our results.
val results = new LinkedBlockingQueue[FetchResult]
- // A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
- val size = blocks.map(_._2).sum
- }
-
// Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
// the number of bytes in flight is limited to maxBytesInFlight
val fetchRequests = new Queue[FetchRequest]
@@ -1052,67 +1065,81 @@ class BlockFetcherIterator(
}
}
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val remoteRequests = new ArrayBuffer[FetchRequest]
- for ((address, blockInfos) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
- } else {
- remoteBlockIds ++= blockInfos.map(_._1)
- // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- curBlocks += ((blockId, size))
- curRequestSize += size
- if (curRequestSize >= minRequestSize) {
- // Add this FetchRequest
+  def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockInfos.map(_._1)
+ } else {
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
remoteRequests += new FetchRequest(address, curBlocks)
- curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
}
}
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
}
+ remoteRequests
}
- // Add the remote requests into our queue in a random order
- fetchRequests ++= Utils.randomize(remoteRequests)
- // Send out initial requests for blocks, up to our maxBytesInFlight
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
+  def getLocalBlocks() {
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ for (id <- localBlockIds) {
+ getLocal(id) match {
+ case Some(iter) => {
+ results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ }
}
- val numGets = remoteBlockIds.size - fetchRequests.size
- logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- startTime = System.currentTimeMillis
- for (id <- localBlockIds) {
- getLocalFromDisk(id, serializer) match {
- case Some(iter) => {
- results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
- logDebug("Got local block " + id)
- }
- case None => {
- throw new BlockException(id, "Could not get block " + id + " from local machine")
- }
+  def initialize() {
+ // Split local and remote blocks.
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
}
+
+ val numGets = remoteBlockIds.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
}
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
//an iterator that will read fetched blocks off the queue as they arrive.
@volatile private var resultsGotten = 0
@@ -1144,3 +1171,144 @@ class BlockFetcherIterator(
def remoteBytesRead = _remoteBytesRead
}
+
+class NettyBlockFetcherIterator(
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer
+) extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
+
+ import blockManager._
+
+ val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
+
+  def putResult(blockId: String, blockSize: Long, blockData: ByteBuffer,
+      results: LinkedBlockingQueue[FetchResult]) {
+    results.put(new FetchResult(
+      blockId, blockSize, () => dataDeserialize(blockId, blockData, serializer)))
+  }
+
+  def startCopiers(numCopiers: Int): List[_ <: Thread] = {
+    (for (i <- 0 until numCopiers) yield {
+      val copier = new Thread {
+        override def run() {
+          try {
+            while (!isInterrupted && !fetchRequestsSync.isEmpty) {
+              sendRequest(fetchRequestsSync.take())
+            }
+          } catch {
+            case x: InterruptedException => logInfo("Copier Interrupted")
+          }
+        }
+      }
+      copier.start()
+      copier
+    }).toList
+  }
+
+  // Keep this to interrupt the threads when necessary
+  def stopCopiers(copiers: List[_ <: Thread]) {
+ for (copier <- copiers) {
+ copier.interrupt()
+ }
+ }
+
+ override def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
+    val cmId = new ConnectionManagerId(
+      req.address.ip, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
+    val cpier = new ShuffleCopier
+    cpier.getBlocks(cmId, req.blocks, (blockId: String, blockSize: Long, blockData: ByteBuf) =>
+      putResult(blockId, blockSize, blockData.nioBuffer, results))
+    logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.ip)
+ }
+
+ override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+    val originalTotalBlocks = totalBlocks
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockInfos.map(_._1)
+ } else {
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+          if (size > 0) {
+            curBlocks += ((blockId, size))
+            curRequestSize += size
+          } else if (size == 0) {
+            // Zero-byte blocks are skipped entirely, so adjust totalBlocks to match.
+            totalBlocks -= 1
+          } else {
+            throw new SparkException("Negative block size " + blockId)
+          }
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks")
+ remoteRequests
+ }
+
+  var copiers: List[_ <: Thread] = null
+
+  override def initialize() {
+ // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ for (request <- Utils.randomize(remoteRequests)) {
+ fetchRequestsSync.put(request)
+ }
+
+ copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
+ logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+  override def next(): (String, Option[Iterator[Any]]) = {
+    resultsGotten += 1
+    val result = results.take()
+    // Once all the results have been retrieved, shut down the copiers.
+    if (resultsGotten == totalBlocks && copiers != null) {
+      stopCopiers(copiers)
+    }
+    (result.blockId, if (result.failed) None else Some(result.deserialize()))
+  }
+ }
+
+  def apply(t: String,
+      blockManager: BlockManager,
+      blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+      serializer: Serializer): BlockFetcherIterator = {
+    val iter = if (t == "netty") {
+      new NettyBlockFetcherIterator(blockManager, blocksByAddress, serializer)
+    } else {
+      new BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer)
+    }
+    iter.initialize()
+    iter
+  }
+
+}
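The companion object's apply is now the one entry point callers should use, since it both builds the right iterator and runs initialize before handing it back. A consuming sketch (blockManager, blocksByAddress, and serializer are whatever the caller already has in scope):

    val iter = BlockFetcherIterator("netty", blockManager, blocksByAddress, serializer)
    for ((blockId, valuesOpt) <- iter) {
      // valuesOpt is None when the fetch failed (FetchResult.size == -1)
    }
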
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 8154b8ca74..82bcbd5bc2 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -14,13 +14,16 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark.Utils
import spark.executor.ExecutorExitCode
import spark.serializer.{Serializer, SerializationStream}
+import spark.Logging
+import spark.network.netty.ShuffleSender
+import spark.network.netty.PathResolver
/**
* Stores BlockManager blocks on disk.
*/
private class DiskStore(blockManager: BlockManager, rootDirs: String)
- extends BlockStore(blockManager) {
+ extends BlockStore(blockManager) with Logging {
class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
extends BlockObjectWriter(blockId) {
@@ -79,19 +82,28 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+  var shuffleSender: Thread = null
+ val thisInstance = this
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
val localDirs = createLocalDirs()
val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
+
addShutdownHook()
+  if (useNetty) {
+ startShuffleBlockSender()
+ }
+
def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
+
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
@@ -262,10 +274,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
localDirs.foreach { localDir =>
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
}
+        if (useNetty && shuffleSender != null) {
+          shuffleSender.stop()
+        }
} catch {
case t: Throwable => logError("Exception while deleting local spark dirs", t)
}
}
})
}
+
+  private def startShuffleBlockSender() {
+ try {
+ val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt
+
+ val pResolver = new PathResolver {
+        def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ return null
+ }
+ thisInstance.getFile(blockId).getAbsolutePath()
+ }
+ }
+ shuffleSender = new Thread {
+ override def run() = {
+          val sender = new ShuffleSender(port, pResolver)
+          logInfo("Created ShuffleSender binding to port: " + port)
+          sender.start()
+ }
+ }
+ shuffleSender.setDaemon(true)
+      shuffleSender.start()
+
+ } catch {
+ case interrupted: InterruptedException =>
+ logInfo("Runner thread for ShuffleBlockSender interrupted")
+
+ case e: Exception => {
+ logError("Error running ShuffleBlockSender ", e)
+ if (shuffleSender != null) {
+          shuffleSender.stop()
+ shuffleSender = null
+ }
+ }
+ }
+ }
}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 190d723435..dbfe5b0aa6 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -160,7 +160,8 @@ object SparkBuild extends Build {
"cc.spray" % "spray-can" % "1.0-M2.1" excludeAll(excludeNetty),
"cc.spray" % "spray-server" % "1.0-M2.1" excludeAll(excludeNetty),
"cc.spray" % "spray-json_2.9.2" % "1.1.1" excludeAll(excludeNetty),
- "org.apache.mesos" % "mesos" % "0.9.0-incubating"
+ "org.apache.mesos" % "mesos" % "0.9.0-incubating",
+ "io.netty" % "netty-all" % "4.0.0.Beta2"
) ++ (
if (HADOOP_MAJOR_VERSION == "2") {
if (HADOOP_YARN) {
diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
index d8b987ec86..bd0b0e74c1 100644
--- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
+++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
@@ -5,7 +5,7 @@ import spark.util.{RateLimitedOutputStream, IntParam}
import java.net.ServerSocket
import spark.{Logging, KryoSerializer}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import io.Source
+import scala.io.Source
import java.io.IOException
/**