diff options
author | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-05-31 23:21:38 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-05-31 23:21:38 -0700 |
commit | 91aca9224936da84b16ea789cb81914579a0db03 (patch) | |
tree | 79319339baeb37a08c76ca32f426ed2b83f7238c | |
parent | 84530ba6d9fa47ee2863bb50c23742ecfa2a6a64 (diff) | |
download | spark-91aca9224936da84b16ea789cb81914579a0db03.tar.gz spark-91aca9224936da84b16ea789cb81914579a0db03.tar.bz2 spark-91aca9224936da84b16ea789cb81914579a0db03.zip |
Another round of Netty fixes.
1. Avoid race condition between stop and copier completion
2. Handle socket exceptions by reporting them and filling in a failed
FetchResult
4 files changed, 58 insertions, 45 deletions
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 3a62dacbc8..9c9b976ebe 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -8,9 +8,12 @@ import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; class FileClient { + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); private FileClientHandler handler = null; private Channel channel = null; private Bootstrap bootstrap = null; @@ -25,25 +28,10 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0) // Disable connect timeout .handler(new FileClientChannelInitializer(handler)); } - public static final class ChannelCloseListener implements ChannelFutureListener { - private FileClient fc = null; - - public ChannelCloseListener(FileClient fc){ - this.fc = fc; - } - - @Override - public void operationComplete(ChannelFuture future) { - if (fc.bootstrap!=null){ - fc.bootstrap.shutdown(); - fc.bootstrap = null; - } - } - } - public void connect(String host, int port) { try { // Start the connection attempt. @@ -58,8 +46,8 @@ class FileClient { public void waitForClose() { try { channel.closeFuture().sync(); - } catch (InterruptedException e){ - e.printStackTrace(); + } catch (InterruptedException e) { + LOG.warn("FileClient interrupted", e); } } diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java index 2069dee5ca..9fc9449827 100644 --- a/core/src/main/java/spark/network/netty/FileClientHandler.java +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -9,7 +9,14 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { private FileHeader currentHeader = null; + private volatile boolean handlerCalled = false; + + public boolean isComplete() { + return handlerCalled; + } + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + public abstract void handleError(String blockId); @Override public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { @@ -26,6 +33,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { // get file if(in.readableBytes() >= currentHeader.fileLen()) { handle(ctx, in, currentHeader); + handlerCalled = true; currentHeader = null; ctx.close(); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index a91f5a886d..8ec46d42fa 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -9,19 +9,35 @@ import io.netty.util.CharsetUtil import spark.Logging import spark.network.ConnectionManagerId +import scala.collection.JavaConverters._ + private[spark] class ShuffleCopier extends Logging { - def getBlock(cmId: ConnectionManagerId, blockId: String, + def getBlock(host: String, port: Int, blockId: String, resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val fc = new FileClient(handler) - fc.init() - fc.connect(cmId.host, cmId.port) - fc.sendRequest(blockId) - fc.waitForClose() - fc.close() + try { + fc.init() + fc.connect(host, port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } catch { + // Handle any socket-related exceptions in FileClient + case e: Exception => { + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + + " failed", e) + handler.handleError(blockId) + } + } + } + + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) } def getBlocks(cmId: ConnectionManagerId, @@ -44,20 +60,18 @@ private[spark] object ShuffleCopier extends Logging { logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } - } - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { - logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + override def handleError(blockId: String) { + if (!isComplete) { + resultCollectCallBack(blockId, -1, null) + } + } } - def runGetBlock(host:String, port:Int, file:String){ - val handler = new ShuffleClientHandler(echoResultCollectCallBack) - val fc = new FileClient(handler) - fc.init(); - fc.connect(host, port) - fc.sendRequest(file) - fc.waitForClose(); - fc.close() + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + if (size != -1) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + } } def main(args: Array[String]) { @@ -71,14 +85,16 @@ private[spark] object ShuffleCopier extends Logging { val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) - for (i <- Range(0, threads)) { - val runnable = new Runnable() { + val tasks = (for (i <- Range(0, threads)) yield { + Executors.callable(new Runnable() { def run() { - runGetBlock(host, port, file) + val copier = new ShuffleCopier() + copier.getBlock(host, port, file, echoResultCollectCallBack) } - } - copiers.execute(runnable) - } + }) + }).asJava + copiers.invokeAll(tasks) copiers.shutdown + System.exit(0) } } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 1d69d658f7..fac416a5b3 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -265,7 +265,7 @@ object BlockFetcherIterator { }).toList } - //keep this to interrupt the threads when necessary + // keep this to interrupt the threads when necessary private def stopCopiers() { for (copier <- copiers) { copier.interrupt() @@ -312,9 +312,10 @@ object BlockFetcherIterator { resultsGotten += 1 val result = results.take() // if all the results has been retrieved, shutdown the copiers - if (resultsGotten == _totalBlocks && copiers != null) { - stopCopiers() - } + // NO need to stop the copiers if we got all the blocks ? + // if (resultsGotten == _totalBlocks && copiers != null) { + // stopCopiers() + // } (result.blockId, if (result.failed) None else Some(result.deserialize())) } } |