From 91aca9224936da84b16ea789cb81914579a0db03 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 31 May 2013 23:21:38 -0700 Subject: Another round of Netty fixes. 1. Avoid race condition between stop and copier completion 2. Handle socket exceptions by reporting them and filling in a failed FetchResult --- .../main/java/spark/network/netty/FileClient.java | 24 +++------ .../spark/network/netty/FileClientHandler.java | 8 +++ .../scala/spark/network/netty/ShuffleCopier.scala | 62 ++++++++++++++-------- .../scala/spark/storage/BlockFetcherIterator.scala | 9 ++-- 4 files changed, 58 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 3a62dacbc8..9c9b976ebe 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -8,9 +8,12 @@ import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; class FileClient { + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); private FileClientHandler handler = null; private Channel channel = null; private Bootstrap bootstrap = null; @@ -25,25 +28,10 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0) // Disable connect timeout .handler(new FileClientChannelInitializer(handler)); } - public static final class ChannelCloseListener implements ChannelFutureListener { - private FileClient fc = null; - - public ChannelCloseListener(FileClient fc){ - this.fc = fc; - } - - @Override - public void operationComplete(ChannelFuture future) { - if (fc.bootstrap!=null){ - fc.bootstrap.shutdown(); - fc.bootstrap = null; - } - } - } - public void connect(String host, int port) { try { // Start the connection attempt. @@ -58,8 +46,8 @@ class FileClient { public void waitForClose() { try { channel.closeFuture().sync(); - } catch (InterruptedException e){ - e.printStackTrace(); + } catch (InterruptedException e) { + LOG.warn("FileClient interrupted", e); } } diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java index 2069dee5ca..9fc9449827 100644 --- a/core/src/main/java/spark/network/netty/FileClientHandler.java +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -9,7 +9,14 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { private FileHeader currentHeader = null; + private volatile boolean handlerCalled = false; + + public boolean isComplete() { + return handlerCalled; + } + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + public abstract void handleError(String blockId); @Override public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { @@ -26,6 +33,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { // get file if(in.readableBytes() >= currentHeader.fileLen()) { handle(ctx, in, currentHeader); + handlerCalled = true; currentHeader = null; ctx.close(); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index a91f5a886d..8ec46d42fa 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -9,19 +9,35 @@ import io.netty.util.CharsetUtil import spark.Logging import spark.network.ConnectionManagerId +import scala.collection.JavaConverters._ + private[spark] class ShuffleCopier extends Logging { - def getBlock(cmId: ConnectionManagerId, blockId: String, + def getBlock(host: String, port: Int, blockId: String, resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val fc = new FileClient(handler) - fc.init() - fc.connect(cmId.host, cmId.port) - fc.sendRequest(blockId) - fc.waitForClose() - fc.close() + try { + fc.init() + fc.connect(host, port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } catch { + // Handle any socket-related exceptions in FileClient + case e: Exception => { + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + + " failed", e) + handler.handleError(blockId) + } + } + } + + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) } def getBlocks(cmId: ConnectionManagerId, @@ -44,20 +60,18 @@ private[spark] object ShuffleCopier extends Logging { logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } - } - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { - logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + override def handleError(blockId: String) { + if (!isComplete) { + resultCollectCallBack(blockId, -1, null) + } + } } - def runGetBlock(host:String, port:Int, file:String){ - val handler = new ShuffleClientHandler(echoResultCollectCallBack) - val fc = new FileClient(handler) - fc.init(); - fc.connect(host, port) - fc.sendRequest(file) - fc.waitForClose(); - fc.close() + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + if (size != -1) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + } } def main(args: Array[String]) { @@ -71,14 +85,16 @@ private[spark] object ShuffleCopier extends Logging { val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) - for (i <- Range(0, threads)) { - val runnable = new Runnable() { + val tasks = (for (i <- Range(0, threads)) yield { + Executors.callable(new Runnable() { def run() { - runGetBlock(host, port, file) + val copier = new ShuffleCopier() + copier.getBlock(host, port, file, echoResultCollectCallBack) } - } - copiers.execute(runnable) - } + }) + }).asJava + copiers.invokeAll(tasks) copiers.shutdown + System.exit(0) } } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 1d69d658f7..fac416a5b3 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -265,7 +265,7 @@ object BlockFetcherIterator { }).toList } - //keep this to interrupt the threads when necessary + // keep this to interrupt the threads when necessary private def stopCopiers() { for (copier <- copiers) { copier.interrupt() @@ -312,9 +312,10 @@ object BlockFetcherIterator { resultsGotten += 1 val result = results.take() // if all the results has been retrieved, shutdown the copiers - if (resultsGotten == _totalBlocks && copiers != null) { - stopCopiers() - } + // NO need to stop the copiers if we got all the blocks ? + // if (resultsGotten == _totalBlocks && copiers != null) { + // stopCopiers() + // } (result.blockId, if (result.failed) None else Some(result.deserialize())) } } -- cgit v1.2.3 From 038cfc1a9acb32f8c17d883ea64f8cbb324ed82c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 31 May 2013 23:32:18 -0700 Subject: Make connect timeout configurable --- core/src/main/java/spark/network/netty/FileClient.java | 6 ++++-- core/src/main/scala/spark/network/netty/ShuffleCopier.scala | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 9c9b976ebe..517772202f 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -17,9 +17,11 @@ class FileClient { private FileClientHandler handler = null; private Channel channel = null; private Bootstrap bootstrap = null; + private int connectTimeout = 60*1000; // 1 min - public FileClient(FileClientHandler handler) { + public FileClient(FileClientHandler handler, int connectTimeout) { this.handler = handler; + this.connectTimeout = connectTimeout; } public void init() { @@ -28,7 +30,7 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0) // Disable connect timeout + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) // Disable connect timeout .handler(new FileClientChannelInitializer(handler)); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index 8ec46d42fa..afb2cdbb3a 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -18,7 +18,8 @@ private[spark] class ShuffleCopier extends Logging { resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val fc = new FileClient(handler) + val fc = new FileClient(handler, + System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt) try { fc.init() fc.connect(host, port) -- cgit v1.2.3 From a058b0acf3e5ae41e64640feeace3d4e32f47401 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:10:00 -0700 Subject: Delete a file for a block if it already exists. --- core/src/main/scala/spark/storage/DiskStore.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c7281200e7..2be5d01e31 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -195,9 +195,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - val file = getFile(blockId) + var file = getFile(blockId) if (!allowAppendExisting && file.exists()) { - throw new Exception("File for block " + blockId + " already exists on disk: " + file) + // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task + // was rescheduled on the same machine as the old task ? + logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") + file.delete() + // Reopen the file + file = getFile(blockId) + // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file } -- cgit v1.2.3 From cd347f547a9a9b7bdd0d3f4734ae5c13be54f75d Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:27:51 -0700 Subject: Reuse the file object as it is valid after delete --- core/src/main/scala/spark/storage/DiskStore.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 2be5d01e31..e51d258a21 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -201,8 +201,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // was rescheduled on the same machine as the old task ? logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") file.delete() - // Reopen the file - file = getFile(blockId) // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file -- cgit v1.2.3 From 96943a1cc054d7cf80eb8d3dfc7fb19ce48d3c0a Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:29:38 -0700 Subject: var to val --- core/src/main/scala/spark/storage/DiskStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index e51d258a21..cd85fa1e9d 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -195,7 +195,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - var file = getFile(blockId) + val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task // was rescheduled on the same machine as the old task ? -- cgit v1.2.3 From c851957fe4798d5dfb8deba7bf79a035a0543c74 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 5 Jun 2013 14:28:38 -0700 Subject: Don't write zero block files with java serializer --- .../scala/spark/storage/BlockFetcherIterator.scala | 5 ++- core/src/main/scala/spark/storage/DiskStore.scala | 46 ++++++++++++++-------- .../scala/spark/storage/ShuffleBlockManager.scala | 2 +- core/src/test/scala/spark/ShuffleSuite.scala | 26 ++++++++++++ 4 files changed, 61 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index fac416a5b3..843069239c 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -71,6 +71,7 @@ object BlockFetcherIterator { logDebug("Getting " + _totalBlocks + " blocks") protected var startTime = System.currentTimeMillis protected val localBlockIds = new ArrayBuffer[String]() + protected val localNonZeroBlocks = new ArrayBuffer[String]() protected val remoteBlockIds = new HashSet[String]() // A queue to hold our results. @@ -129,6 +130,8 @@ object BlockFetcherIterator { for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { localBlockIds ++= blockInfos.map(_._1) + localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size) } else { remoteBlockIds ++= blockInfos.map(_._1) // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them @@ -172,7 +175,7 @@ object BlockFetcherIterator { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight - for (id <- localBlockIds) { + for (id <- localNonZeroBlocks) { getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index cd85fa1e9d..c1cff25552 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private var bs: OutputStream = null private var objOut: SerializationStream = null private var lastValidPosition = 0L + private var initialized = false override def open(): DiskBlockObjectWriter = { val fos = new FileOutputStream(f, true) channel = fos.getChannel() bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) objOut = serializer.newInstance().serializeStream(bs) + initialized = true this } override def close() { - objOut.close() - bs.close() - channel = null - bs = null - objOut = null + if (initialized) { + objOut.close() + bs.close() + channel = null + bs = null + objOut = null + } // Invoke the close callback handler. super.close() } @@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Flush the partial writes, and set valid length to be the length of the entire file. // Return the number of bytes written for this commit. override def commit(): Long = { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } } override def revertPartialWrites() { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } } override def write(value: Any) { + if (!initialized) { + open() + } objOut.writeObject(value) } diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 49eabfb0d2..44638e0c2d 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open() + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) } new ShuffleWriterGroup(mapId, writers) } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index b967016cf7..33b02fff80 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -367,6 +367,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } + test("zero sized blocks without kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer doesn't create zero-sized blocks. + // So, use Kryo + val c = new ShuffledRDD(b, new HashPartitioner(10)) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + } object ShuffleSuite { -- cgit v1.2.3 From cb2f5046ee99582a5038a78478c23468b14c134e Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 5 Jun 2013 15:09:02 -0700 Subject: Pass in bufferSize to BufferedOutputStream --- core/src/main/scala/spark/storage/DiskStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c1cff25552..0af6e4a359 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -40,7 +40,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def open(): DiskBlockObjectWriter = { val fos = new FileOutputStream(f, true) channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) + bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) initialized = true this -- cgit v1.2.3 From ac480fd977e0de97bcfe646e39feadbd239c1c29 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 6 Jun 2013 16:34:27 -0700 Subject: Clean up variables and counters in BlockFetcherIterator --- .../scala/spark/storage/BlockFetcherIterator.scala | 54 +++++++++++++--------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 843069239c..bb78207c9f 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -67,12 +67,20 @@ object BlockFetcherIterator { throw new IllegalArgumentException("BlocksByAddress is null") } - protected var _totalBlocks = blocksByAddress.map(_._2.size).sum - logDebug("Getting " + _totalBlocks + " blocks") + // Total number blocks fetched (local + remote). Also number of FetchResults expected + protected var _numBlocksToFetch = 0 + protected var startTime = System.currentTimeMillis - protected val localBlockIds = new ArrayBuffer[String]() - protected val localNonZeroBlocks = new ArrayBuffer[String]() - protected val remoteBlockIds = new HashSet[String]() + + // This represents the number of local blocks, also counting zero-sized blocks + private var numLocal = 0 + // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks + protected val localBlocksToFetch = new ArrayBuffer[String]() + + // This represents the number of remote blocks, also counting zero-sized blocks + private var numRemote = 0 + // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks + protected val remoteBlocksToFetch = new HashSet[String]() // A queue to hold our results. protected val results = new LinkedBlockingQueue[FetchResult] @@ -125,15 +133,15 @@ object BlockFetcherIterator { protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. - val originalTotalBlocks = _totalBlocks val remoteRequests = new ArrayBuffer[FetchRequest] for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) - _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size) + numLocal = blockInfos.size + // Filter out zero-sized blocks + localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) + _numBlocksToFetch += localBlocksToFetch.size } else { - remoteBlockIds ++= blockInfos.map(_._1) + numRemote += blockInfos.size // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. @@ -147,10 +155,10 @@ object BlockFetcherIterator { // Skip empty blocks if (size > 0) { curBlocks += ((blockId, size)) + remoteBlocksToFetch += blockId + _numBlocksToFetch += 1 curRequestSize += size - } else if (size == 0) { - _totalBlocks -= 1 - } else { + } else if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= minRequestSize) { @@ -166,8 +174,8 @@ object BlockFetcherIterator { } } } - logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " + - originalTotalBlocks + " blocks") + logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " + + totalBlocks + " blocks") remoteRequests } @@ -175,7 +183,7 @@ object BlockFetcherIterator { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight - for (id <- localNonZeroBlocks) { + for (id <- localBlocksToFetch) { getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight @@ -201,7 +209,7 @@ object BlockFetcherIterator { sendRequest(fetchRequests.dequeue()) } - val numGets = remoteBlockIds.size - fetchRequests.size + val numGets = remoteRequests.size - fetchRequests.size logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) // Get Local Blocks @@ -213,7 +221,7 @@ object BlockFetcherIterator { //an iterator that will read fetched blocks off the queue as they arrive. @volatile protected var resultsGotten = 0 - override def hasNext: Boolean = resultsGotten < _totalBlocks + override def hasNext: Boolean = resultsGotten < _numBlocksToFetch override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 @@ -230,9 +238,9 @@ object BlockFetcherIterator { } // Implementing BlockFetchTracker trait. - override def totalBlocks: Int = _totalBlocks - override def numLocalBlocks: Int = localBlockIds.size - override def numRemoteBlocks: Int = remoteBlockIds.size + override def totalBlocks: Int = numLocal + numRemote + override def numLocalBlocks: Int = numLocal + override def numRemoteBlocks: Int = numRemote override def remoteFetchTime: Long = _remoteFetchTime override def fetchWaitTime: Long = _fetchWaitTime override def remoteBytesRead: Long = _remoteBytesRead @@ -294,7 +302,7 @@ object BlockFetcherIterator { private var copiers: List[_ <: Thread] = null override def initialize() { - // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks + // Split Local Remote Blocks and set numBlocksToFetch val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order for (request <- Utils.randomize(remoteRequests)) { @@ -316,7 +324,7 @@ object BlockFetcherIterator { val result = results.take() // if all the results has been retrieved, shutdown the copiers // NO need to stop the copiers if we got all the blocks ? - // if (resultsGotten == _totalBlocks && copiers != null) { + // if (resultsGotten == _numBlocksToFetch && copiers != null) { // stopCopiers() // } (result.blockId, if (result.failed) None else Some(result.deserialize())) -- cgit v1.2.3 From 1d9f0df0652f455145d2dfed43a9407df6de6c43 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 13 Jun 2013 14:46:25 -0700 Subject: Fix some comments and style --- core/src/main/java/spark/network/netty/FileClient.java | 2 +- core/src/main/scala/spark/network/netty/ShuffleCopier.scala | 8 ++++---- core/src/main/scala/spark/storage/BlockFetcherIterator.scala | 6 +----- core/src/main/scala/spark/storage/DiskStore.scala | 3 +-- core/src/test/scala/spark/ShuffleSuite.scala | 3 +-- 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 517772202f..a4bb4bc701 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -30,7 +30,7 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) // Disable connect timeout + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) .handler(new FileClientChannelInitializer(handler)); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index afb2cdbb3a..8d5194a737 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -18,8 +18,9 @@ private[spark] class ShuffleCopier extends Logging { resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val fc = new FileClient(handler, - System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt) + val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt + val fc = new FileClient(handler, connectTimeout) + try { fc.init() fc.connect(host, port) @@ -29,8 +30,7 @@ private[spark] class ShuffleCopier extends Logging { } catch { // Handle any socket-related exceptions in FileClient case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + - " failed", e) + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) handler.handleError(blockId) } } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index bb78207c9f..bec876213e 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -322,11 +322,7 @@ object BlockFetcherIterator { override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() - // if all the results has been retrieved, shutdown the copiers - // NO need to stop the copiers if we got all the blocks ? - // if (resultsGotten == _numBlocksToFetch && copiers != null) { - // stopCopiers() - // } + // If all the results has been retrieved, copiers will exit automatically (result.blockId, if (result.failed) None else Some(result.deserialize())) } } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 0af6e4a359..15ab840155 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -212,10 +212,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task ? + // was rescheduled on the same machine as the old task. logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") file.delete() - // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 33b02fff80..1916885a73 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -376,8 +376,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) - // NOTE: The default Java serializer doesn't create zero-sized blocks. - // So, use Kryo + // NOTE: The default Java serializer should create zero-sized blocks val c = new ShuffledRDD(b, new HashPartitioner(10)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId -- cgit v1.2.3