author     Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-06-12 16:38:37 -0700
committer  Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-06-12 16:38:37 -0700
commit     5da4287b1dbfb8cfcec9c915926d8a8755bd52b2 (patch)
tree       5b0d5e79fed362bfc67fcc56345ee17102b8423f /core/src
parent     5e9a9317c5ef4c549994989bae7826dcb37d1a3b (diff)
parent     ac480fd977e0de97bcfe646e39feadbd239c1c29 (diff)
Merge branch 'netty-dbg' of github.com:shivaram/spark into netty-dbg
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/spark/storage/BlockFetcherIterator.scala | 51
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala             | 54
-rw-r--r--  core/src/main/scala/spark/storage/ShuffleBlockManager.scala   |  2
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala                  | 26
4 files changed, 94 insertions(+), 39 deletions(-)
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index fac416a5b3..bb78207c9f 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -67,11 +67,20 @@ object BlockFetcherIterator {
throw new IllegalArgumentException("BlocksByAddress is null")
}
- protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
- logDebug("Getting " + _totalBlocks + " blocks")
+ // Total number of blocks to fetch (local + remote). Also the number of FetchResults expected
+ protected var _numBlocksToFetch = 0
+
protected var startTime = System.currentTimeMillis
- protected val localBlockIds = new ArrayBuffer[String]()
- protected val remoteBlockIds = new HashSet[String]()
+
+ // This represents the number of local blocks, also counting zero-sized blocks
+ private var numLocal = 0
+ // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
+ protected val localBlocksToFetch = new ArrayBuffer[String]()
+
+ // This represents the number of remote blocks, also counting zero-sized blocks
+ private var numRemote = 0
+ // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
+ protected val remoteBlocksToFetch = new HashSet[String]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
@@ -124,13 +133,15 @@ object BlockFetcherIterator {
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
- val originalTotalBlocks = _totalBlocks
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
+ numLocal = blockInfos.size
+ // Filter out zero-sized blocks
+ localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
+ _numBlocksToFetch += localBlocksToFetch.size
} else {
- remoteBlockIds ++= blockInfos.map(_._1)
+ numRemote += blockInfos.size
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
@@ -144,10 +155,10 @@ object BlockFetcherIterator {
// Skip empty blocks
if (size > 0) {
curBlocks += ((blockId, size))
+ remoteBlocksToFetch += blockId
+ _numBlocksToFetch += 1
curRequestSize += size
- } else if (size == 0) {
- _totalBlocks -= 1
- } else {
+ } else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
if (curRequestSize >= minRequestSize) {
@@ -163,8 +174,8 @@ object BlockFetcherIterator {
}
}
}
- logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
- originalTotalBlocks + " blocks")
+ logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+ totalBlocks + " blocks")
remoteRequests
}
@@ -172,7 +183,7 @@ object BlockFetcherIterator {
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
// these all at once because they will just memory-map some files, so they won't consume
// any memory that might exceed our maxBytesInFlight
- for (id <- localBlockIds) {
+ for (id <- localBlocksToFetch) {
getLocalFromDisk(id, serializer) match {
case Some(iter) => {
// Pass 0 as size since it's not in flight
@@ -198,7 +209,7 @@ object BlockFetcherIterator {
sendRequest(fetchRequests.dequeue())
}
- val numGets = remoteBlockIds.size - fetchRequests.size
+ val numGets = remoteRequests.size - fetchRequests.size
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
@@ -210,7 +221,7 @@ object BlockFetcherIterator {
//an iterator that will read fetched blocks off the queue as they arrive.
@volatile protected var resultsGotten = 0
- override def hasNext: Boolean = resultsGotten < _totalBlocks
+ override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
@@ -227,9 +238,9 @@ object BlockFetcherIterator {
}
// Implementing BlockFetchTracker trait.
- override def totalBlocks: Int = _totalBlocks
- override def numLocalBlocks: Int = localBlockIds.size
- override def numRemoteBlocks: Int = remoteBlockIds.size
+ override def totalBlocks: Int = numLocal + numRemote
+ override def numLocalBlocks: Int = numLocal
+ override def numRemoteBlocks: Int = numRemote
override def remoteFetchTime: Long = _remoteFetchTime
override def fetchWaitTime: Long = _fetchWaitTime
override def remoteBytesRead: Long = _remoteBytesRead
@@ -291,7 +302,7 @@ object BlockFetcherIterator {
private var copiers: List[_ <: Thread] = null
override def initialize() {
- // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
+ // Split Local Remote Blocks and set numBlocksToFetch
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
for (request <- Utils.randomize(remoteRequests)) {
@@ -313,7 +324,7 @@ object BlockFetcherIterator {
val result = results.take()
// if all the results has been retrieved, shutdown the copiers
// NO need to stop the copiers if we got all the blocks ?
- // if (resultsGotten == _totalBlocks && copiers != null) {
+ // if (resultsGotten == _numBlocksToFetch && copiers != null) {
// stopCopiers()
// }
(result.blockId, if (result.failed) None else Some(result.deserialize()))
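
The BlockFetcherIterator change above replaces the single _totalBlocks counter with _numBlocksToFetch, which counts only non-zero-sized blocks, while the new numLocal and numRemote fields keep the raw per-location totals for the BlockFetchTracker metrics. A minimal, self-contained sketch of that split-and-filter step (simplified types and a hypothetical splitBlocks helper, not Spark's actual API):

    import scala.collection.mutable.ArrayBuffer

    // Sketch: split (blockId, size) pairs per address into local and remote fetch
    // lists, skipping zero-sized blocks while still counting them in the totals.
    object SplitBlocksSketch {
      case class Split(
          numLocal: Int,                 // local blocks, including zero-sized ones
          numRemote: Int,                // remote blocks, including zero-sized ones
          localToFetch: Seq[String],     // non-empty local blocks
          remoteToFetch: Seq[String]) {  // non-empty remote blocks
        def totalBlocks: Int = numLocal + numRemote
        def numBlocksToFetch: Int = localToFetch.size + remoteToFetch.size
      }

      def splitBlocks(
          blocksByAddress: Seq[(String, Seq[(String, Long)])],
          localAddress: String): Split = {
        var numLocal, numRemote = 0
        val localToFetch = new ArrayBuffer[String]
        val remoteToFetch = new ArrayBuffer[String]
        for ((address, blockInfos) <- blocksByAddress) {
          if (address == localAddress) {
            numLocal += blockInfos.size
            localToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
          } else {
            numRemote += blockInfos.size
            for ((blockId, size) <- blockInfos) {
              if (size > 0) remoteToFetch += blockId
              else if (size < 0) throw new IllegalArgumentException(
                "Negative block size " + size + " for " + blockId)
            }
          }
        }
        Split(numLocal, numRemote, localToFetch.toSeq, remoteToFetch.toSeq)
      }

      def main(args: Array[String]): Unit = {
        val split = splitBlocks(
          Seq("hostA" -> Seq(("b1", 10L), ("b2", 0L)),
              "hostB" -> Seq(("b3", 5L), ("b4", 0L))),
          localAddress = "hostA")
        // Expect 4 total blocks but only 2 to fetch, mirroring hasNext's new bound.
        println(split.totalBlocks + " total, " + split.numBlocksToFetch + " to fetch")
      }
    }

With this split, hasNext can compare resultsGotten against the number of blocks actually being fetched rather than the raw total, so the iterator no longer waits for FetchResults that will never arrive for empty blocks.
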
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index c7281200e7..0af6e4a359 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private var bs: OutputStream = null
private var objOut: SerializationStream = null
private var lastValidPosition = 0L
+ private var initialized = false
override def open(): DiskBlockObjectWriter = {
val fos = new FileOutputStream(f, true)
channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
+ bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
this
}
override def close() {
- objOut.close()
- bs.close()
- channel = null
- bs = null
- objOut = null
+ if (initialized) {
+ objOut.close()
+ bs.close()
+ channel = null
+ bs = null
+ objOut = null
+ }
// Invoke the close callback handler.
super.close()
}
@@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
// Flush the partial writes, and set valid length to be the length of the entire file.
// Return the number of bytes written for this commit.
override def commit(): Long = {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
}
override def revertPartialWrites() {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
}
override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
objOut.writeObject(value)
}
@@ -197,7 +211,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
if (!allowAppendExisting && file.exists()) {
- throw new Exception("File for block " + blockId + " already exists on disk: " + file)
+ // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
+ // was rescheduled on the same machine as the old task ?
+ logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
+ file.delete()
+ // throw new Exception("File for block " + blockId + " already exists on disk: " + file)
}
file
}
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
index 49eabfb0d2..44638e0c2d 100644
--- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
}
new ShuffleWriterGroup(mapId, writers)
}
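
Taken together, the DiskStore and ShuffleBlockManager changes above make the shuffle block writer lazy: getDiskBlockWriter no longer calls open() eagerly, the writer opens its file on the first write(), and close()/commit()/revertPartialWrites() check the new initialized flag, so a bucket that never receives a record never creates a file on disk. A hedged, simplified sketch of that lazy-open pattern (a hypothetical LazyBlockWriter over a plain FileOutputStream, with no compression or serialization, not the actual Spark implementation):

    import java.io.{File, FileOutputStream, OutputStream}

    // Sketch: a writer that only touches the disk once something is written.
    class LazyBlockWriter(file: File) {
      private var out: OutputStream = null
      private var initialized = false
      private var lastValidPosition = 0L

      private def open(): Unit = {
        out = new FileOutputStream(file, true)   // append mode, as in the diff above
        initialized = true
      }

      def write(bytes: Array[Byte]): Unit = {
        if (!initialized) open()                 // open lazily on first write
        out.write(bytes)
      }

      // Returns the number of bytes made durable by this commit.
      def commit(): Long = {
        if (initialized) {
          out.flush()
          val prevPos = lastValidPosition
          lastValidPosition = file.length()
          lastValidPosition - prevPos
        } else {
          0L                                     // nothing written, nothing to commit
        }
      }

      def close(): Unit = {
        if (initialized) {
          out.close()
          out = null
        }
      }
    }

Usage mirrors the new ShuffleBlockManager code path: one writer is still created per bucket up front, but only the buckets that actually receive records ever open a stream and create a file.
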
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index b967016cf7..33b02fff80 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -367,6 +367,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(nonEmptyBlocks.size <= 4)
}
+ test("zero sized blocks without kryo") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
+
+ // NOTE: This test intentionally uses the default Java serializer (no Kryo)
+ // to check that empty shuffle blocks are handled in that code path as well
+ val c = new ShuffledRDD(b, new HashPartitioner(10))
+
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
+
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
+ }
+
}
object ShuffleSuite {