diff options
author | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-06-05 14:28:38 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-06-05 14:28:38 -0700 |
commit | c851957fe4798d5dfb8deba7bf79a035a0543c74 (patch) | |
tree | 193e29ab284cf1a0d2c61413cdebeca90767c247 /core | |
parent | 96943a1cc054d7cf80eb8d3dfc7fb19ce48d3c0a (diff) | |
download | spark-c851957fe4798d5dfb8deba7bf79a035a0543c74.tar.gz spark-c851957fe4798d5dfb8deba7bf79a035a0543c74.tar.bz2 spark-c851957fe4798d5dfb8deba7bf79a035a0543c74.zip |
Don't write zero block files with java serializer
Diffstat (limited to 'core')
4 files changed, 61 insertions, 18 deletions
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index fac416a5b3..843069239c 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -71,6 +71,7 @@ object BlockFetcherIterator { logDebug("Getting " + _totalBlocks + " blocks") protected var startTime = System.currentTimeMillis protected val localBlockIds = new ArrayBuffer[String]() + protected val localNonZeroBlocks = new ArrayBuffer[String]() protected val remoteBlockIds = new HashSet[String]() // A queue to hold our results. @@ -129,6 +130,8 @@ object BlockFetcherIterator { for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { localBlockIds ++= blockInfos.map(_._1) + localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size) } else { remoteBlockIds ++= blockInfos.map(_._1) // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them @@ -172,7 +175,7 @@ object BlockFetcherIterator { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight - for (id <- localBlockIds) { + for (id <- localNonZeroBlocks) { getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index cd85fa1e9d..c1cff25552 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private var bs: OutputStream = null private var objOut: SerializationStream = null private var lastValidPosition = 0L + private var initialized = false override def open(): DiskBlockObjectWriter = { val fos = new FileOutputStream(f, true) channel = fos.getChannel() bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) objOut = serializer.newInstance().serializeStream(bs) + initialized = true this } override def close() { - objOut.close() - bs.close() - channel = null - bs = null - objOut = null + if (initialized) { + objOut.close() + bs.close() + channel = null + bs = null + objOut = null + } // Invoke the close callback handler. super.close() } @@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Flush the partial writes, and set valid length to be the length of the entire file. // Return the number of bytes written for this commit. override def commit(): Long = { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } } override def revertPartialWrites() { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } } override def write(value: Any) { + if (!initialized) { + open() + } objOut.writeObject(value) } diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 49eabfb0d2..44638e0c2d 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open() + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) } new ShuffleWriterGroup(mapId, writers) } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index b967016cf7..33b02fff80 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -367,6 +367,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } + test("zero sized blocks without kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer doesn't create zero-sized blocks. + // So, use Kryo + val c = new ShuffledRDD(b, new HashPartitioner(10)) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + } object ShuffleSuite { |