aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-06-22 16:22:47 -0700
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-06-22 16:22:47 -0700
commit3e61beff7b41217a40afdccd1e413d9b90fe6e54 (patch)
treecd820cc06b7d3281573bf7366fc07003b318ebb4
parentd92d3f7938dec954ea31de232f50cafd4b644065 (diff)
parent1d9f0df0652f455145d2dfed43a9407df6de6c43 (diff)
downloadspark-3e61beff7b41217a40afdccd1e413d9b90fe6e54.tar.gz
spark-3e61beff7b41217a40afdccd1e413d9b90fe6e54.tar.bz2
spark-3e61beff7b41217a40afdccd1e413d9b90fe6e54.zip
Merge pull request #648 from shivaram/netty-dbg
Shuffle fixes and cleanup
-rw-r--r--core/src/main/java/spark/network/netty/FileClient.java28
-rw-r--r--core/src/main/java/spark/network/netty/FileClientHandler.java8
-rw-r--r--core/src/main/scala/spark/network/netty/ShuffleCopier.scala65
-rw-r--r--core/src/main/scala/spark/storage/BlockFetcherIterator.scala56
-rw-r--r--core/src/main/scala/spark/storage/DiskStore.scala53
-rw-r--r--core/src/main/scala/spark/storage/ShuffleBlockManager.scala2
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala25
7 files changed, 151 insertions, 86 deletions
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java
index 3a62dacbc8..a4bb4bc701 100644
--- a/core/src/main/java/spark/network/netty/FileClient.java
+++ b/core/src/main/java/spark/network/netty/FileClient.java
@@ -8,15 +8,20 @@ import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
class FileClient {
+ private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private FileClientHandler handler = null;
private Channel channel = null;
private Bootstrap bootstrap = null;
+ private int connectTimeout = 60*1000; // 1 min
- public FileClient(FileClientHandler handler) {
+ public FileClient(FileClientHandler handler, int connectTimeout) {
this.handler = handler;
+ this.connectTimeout = connectTimeout;
}
public void init() {
@@ -25,25 +30,10 @@ class FileClient {
.channel(OioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
.handler(new FileClientChannelInitializer(handler));
}
- public static final class ChannelCloseListener implements ChannelFutureListener {
- private FileClient fc = null;
-
- public ChannelCloseListener(FileClient fc){
- this.fc = fc;
- }
-
- @Override
- public void operationComplete(ChannelFuture future) {
- if (fc.bootstrap!=null){
- fc.bootstrap.shutdown();
- fc.bootstrap = null;
- }
- }
- }
-
public void connect(String host, int port) {
try {
// Start the connection attempt.
@@ -58,8 +48,8 @@ class FileClient {
public void waitForClose() {
try {
channel.closeFuture().sync();
- } catch (InterruptedException e){
- e.printStackTrace();
+ } catch (InterruptedException e) {
+ LOG.warn("FileClient interrupted", e);
}
}
diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java
index 2069dee5ca..9fc9449827 100644
--- a/core/src/main/java/spark/network/netty/FileClientHandler.java
+++ b/core/src/main/java/spark/network/netty/FileClientHandler.java
@@ -9,7 +9,14 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
private FileHeader currentHeader = null;
+ private volatile boolean handlerCalled = false;
+
+ public boolean isComplete() {
+ return handlerCalled;
+ }
+
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
+ public abstract void handleError(String blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
@@ -26,6 +33,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
// get file
if(in.readableBytes() >= currentHeader.fileLen()) {
handle(ctx, in, currentHeader);
+ handlerCalled = true;
currentHeader = null;
ctx.close();
}
diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
index a91f5a886d..8d5194a737 100644
--- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
@@ -9,19 +9,36 @@ import io.netty.util.CharsetUtil
import spark.Logging
import spark.network.ConnectionManagerId
+import scala.collection.JavaConverters._
+
private[spark] class ShuffleCopier extends Logging {
- def getBlock(cmId: ConnectionManagerId, blockId: String,
+ def getBlock(host: String, port: Int, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
- val fc = new FileClient(handler)
- fc.init()
- fc.connect(cmId.host, cmId.port)
- fc.sendRequest(blockId)
- fc.waitForClose()
- fc.close()
+ val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val fc = new FileClient(handler, connectTimeout)
+
+ try {
+ fc.init()
+ fc.connect(host, port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ } catch {
+ // Handle any socket-related exceptions in FileClient
+ case e: Exception => {
+ logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
+ handler.handleError(blockId)
+ }
+ }
+ }
+
+ def getBlock(cmId: ConnectionManagerId, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
@@ -44,20 +61,18 @@ private[spark] object ShuffleCopier extends Logging {
logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
- }
- def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
- logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ override def handleError(blockId: String) {
+ if (!isComplete) {
+ resultCollectCallBack(blockId, -1, null)
+ }
+ }
}
- def runGetBlock(host:String, port:Int, file:String){
- val handler = new ShuffleClientHandler(echoResultCollectCallBack)
- val fc = new FileClient(handler)
- fc.init();
- fc.connect(host, port)
- fc.sendRequest(file)
- fc.waitForClose();
- fc.close()
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ if (size != -1) {
+ logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
}
def main(args: Array[String]) {
@@ -71,14 +86,16 @@ private[spark] object ShuffleCopier extends Logging {
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
- for (i <- Range(0, threads)) {
- val runnable = new Runnable() {
+ val tasks = (for (i <- Range(0, threads)) yield {
+ Executors.callable(new Runnable() {
def run() {
- runGetBlock(host, port, file)
+ val copier = new ShuffleCopier()
+ copier.getBlock(host, port, file, echoResultCollectCallBack)
}
- }
- copiers.execute(runnable)
- }
+ })
+ }).asJava
+ copiers.invokeAll(tasks)
copiers.shutdown
+ System.exit(0)
}
}
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index 1d69d658f7..bec876213e 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -67,11 +67,20 @@ object BlockFetcherIterator {
throw new IllegalArgumentException("BlocksByAddress is null")
}
- protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
- logDebug("Getting " + _totalBlocks + " blocks")
+ // Total number of blocks fetched (local + remote). Also number of FetchResults expected
+ protected var _numBlocksToFetch = 0
+
protected var startTime = System.currentTimeMillis
- protected val localBlockIds = new ArrayBuffer[String]()
- protected val remoteBlockIds = new HashSet[String]()
+
+ // This represents the number of local blocks, also counting zero-sized blocks
+ private var numLocal = 0
+ // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
+ protected val localBlocksToFetch = new ArrayBuffer[String]()
+
+ // This represents the number of remote blocks, also counting zero-sized blocks
+ private var numRemote = 0
+ // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
+ protected val remoteBlocksToFetch = new HashSet[String]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
@@ -124,13 +133,15 @@ object BlockFetcherIterator {
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
- val originalTotalBlocks = _totalBlocks
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
+ numLocal = blockInfos.size
+ // Filter out zero-sized blocks
+ localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
+ _numBlocksToFetch += localBlocksToFetch.size
} else {
- remoteBlockIds ++= blockInfos.map(_._1)
+ numRemote += blockInfos.size
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
@@ -144,10 +155,10 @@ object BlockFetcherIterator {
// Skip empty blocks
if (size > 0) {
curBlocks += ((blockId, size))
+ remoteBlocksToFetch += blockId
+ _numBlocksToFetch += 1
curRequestSize += size
- } else if (size == 0) {
- _totalBlocks -= 1
- } else {
+ } else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
if (curRequestSize >= minRequestSize) {
@@ -163,8 +174,8 @@ object BlockFetcherIterator {
}
}
}
- logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
- originalTotalBlocks + " blocks")
+ logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+ totalBlocks + " blocks")
remoteRequests
}
@@ -172,7 +183,7 @@ object BlockFetcherIterator {
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
// these all at once because they will just memory-map some files, so they won't consume
// any memory that might exceed our maxBytesInFlight
- for (id <- localBlockIds) {
+ for (id <- localBlocksToFetch) {
getLocalFromDisk(id, serializer) match {
case Some(iter) => {
// Pass 0 as size since it's not in flight
@@ -198,7 +209,7 @@ object BlockFetcherIterator {
sendRequest(fetchRequests.dequeue())
}
- val numGets = remoteBlockIds.size - fetchRequests.size
+ val numGets = remoteRequests.size - fetchRequests.size
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
@@ -210,7 +221,7 @@ object BlockFetcherIterator {
//an iterator that will read fetched blocks off the queue as they arrive.
@volatile protected var resultsGotten = 0
- override def hasNext: Boolean = resultsGotten < _totalBlocks
+ override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
@@ -227,9 +238,9 @@ object BlockFetcherIterator {
}
// Implementing BlockFetchTracker trait.
- override def totalBlocks: Int = _totalBlocks
- override def numLocalBlocks: Int = localBlockIds.size
- override def numRemoteBlocks: Int = remoteBlockIds.size
+ override def totalBlocks: Int = numLocal + numRemote
+ override def numLocalBlocks: Int = numLocal
+ override def numRemoteBlocks: Int = numRemote
override def remoteFetchTime: Long = _remoteFetchTime
override def fetchWaitTime: Long = _fetchWaitTime
override def remoteBytesRead: Long = _remoteBytesRead
@@ -265,7 +276,7 @@ object BlockFetcherIterator {
}).toList
}
- //keep this to interrupt the threads when necessary
+ // keep this to interrupt the threads when necessary
private def stopCopiers() {
for (copier <- copiers) {
copier.interrupt()
@@ -291,7 +302,7 @@ object BlockFetcherIterator {
private var copiers: List[_ <: Thread] = null
override def initialize() {
- // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
+ // Split Local Remote Blocks and set numBlocksToFetch
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
for (request <- Utils.randomize(remoteRequests)) {
@@ -311,10 +322,7 @@ object BlockFetcherIterator {
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
- // if all the results has been retrieved, shutdown the copiers
- if (resultsGotten == _totalBlocks && copiers != null) {
- stopCopiers()
- }
+ // If all the results have been retrieved, copiers will exit automatically
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
}
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 9914beec99..da859eebcb 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private var bs: OutputStream = null
private var objOut: SerializationStream = null
private var lastValidPosition = 0L
+ private var initialized = false
override def open(): DiskBlockObjectWriter = {
val fos = new FileOutputStream(f, true)
channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
+ bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
this
}
override def close() {
- objOut.close()
- bs.close()
- channel = null
- bs = null
- objOut = null
+ if (initialized) {
+ objOut.close()
+ bs.close()
+ channel = null
+ bs = null
+ objOut = null
+ }
// Invoke the close callback handler.
super.close()
}
@@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
// Flush the partial writes, and set valid length to be the length of the entire file.
// Return the number of bytes written for this commit.
override def commit(): Long = {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
}
override def revertPartialWrites() {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
}
override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
objOut.writeObject(value)
}
@@ -196,7 +210,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
if (!allowAppendExisting && file.exists()) {
- throw new Exception("File for block " + blockId + " already exists on disk: " + file)
+ // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
+ // was rescheduled on the same machine as the old task.
+ logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
+ file.delete()
}
file
}
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
index 49eabfb0d2..44638e0c2d 100644
--- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
}
new ShuffleWriterGroup(mapId, writers)
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index b967016cf7..1916885a73 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -367,6 +367,31 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(nonEmptyBlocks.size <= 4)
}
+ test("zero sized blocks without kryo") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
+
+ // NOTE: The default Java serializer should create zero-sized blocks
+ val c = new ShuffledRDD(b, new HashPartitioner(10))
+
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
+
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
+ }
+
}
object ShuffleSuite {