author     Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-06-05 14:28:38 -0700
committer  Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-06-05 14:28:38 -0700
commit     c851957fe4798d5dfb8deba7bf79a035a0543c74 (patch)
tree       193e29ab284cf1a0d2c61413cdebeca90767c247 /core/src
parent     96943a1cc054d7cf80eb8d3dfc7fb19ce48d3c0a (diff)
Don't write zero block files with java serializer
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/spark/storage/BlockFetcherIterator.scala |  5
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala            | 46
-rw-r--r--  core/src/main/scala/spark/storage/ShuffleBlockManager.scala  |  2
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala                 | 26
4 files changed, 61 insertions(+), 18 deletions(-)
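
Taken together, the change makes the shuffle write path stop creating files for empty partitions: DiskBlockObjectWriter now opens its file lazily on the first write() (commit, revert and close become no-ops until then), ShuffleBlockManager hands out its writers unopened, and BlockFetcherIterator skips local blocks whose reported size is zero. A minimal, self-contained sketch of the lazy-open idea follows; the class below is illustrative only, not the actual Spark writer:

import java.io.{File, FileOutputStream, OutputStream}

// Illustrative lazily-opened writer: no file is created until the first
// write() call, so a partition that never receives a record leaves no file.
class LazyBlockWriter(file: File) {
  private var out: OutputStream = null
  private var initialized = false

  private def open(): Unit = {
    out = new FileOutputStream(file, true)  // the file is created only here
    initialized = true
  }

  def write(bytes: Array[Byte]): Unit = {
    if (!initialized) open()                // first write triggers open()
    out.write(bytes)
  }

  def close(): Unit = {
    if (initialized) {                      // nothing to close if never opened
      out.close()
      out = null
    }
  }
}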
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index fac416a5b3..843069239c 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -71,6 +71,7 @@ object BlockFetcherIterator {
logDebug("Getting " + _totalBlocks + " blocks")
protected var startTime = System.currentTimeMillis
protected val localBlockIds = new ArrayBuffer[String]()
+ protected val localNonZeroBlocks = new ArrayBuffer[String]()
protected val remoteBlockIds = new HashSet[String]()
// A queue to hold our results.
@@ -129,6 +130,8 @@ object BlockFetcherIterator {
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
localBlockIds ++= blockInfos.map(_._1)
+ localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
+ _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size)
} else {
remoteBlockIds ++= blockInfos.map(_._1)
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
@@ -172,7 +175,7 @@ object BlockFetcherIterator {
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
// these all at once because they will just memory-map some files, so they won't consume
// any memory that might exceed our maxBytesInFlight
- for (id <- localBlockIds) {
+ for (id <- localNonZeroBlocks) {
getLocalFromDisk(id, serializer) match {
case Some(iter) => {
// Pass 0 as size since it's not in flight
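
The BlockFetcherIterator change above drops zero-sized blocks from the local fetch list and shrinks _totalBlocks by the number of blocks removed, so the iterator still yields exactly as many results as it expects. A tiny illustration of that filter, using made-up block IDs and sizes:

// Hypothetical (blockId, size) pairs for blocks stored on this node.
val blockInfos = Seq(("shuffle_0_0_0", 0L), ("shuffle_0_1_0", 128L), ("shuffle_0_2_0", 0L))

val localBlockIds      = blockInfos.map(_._1)
val localNonZeroBlocks = blockInfos.filter(_._2 != 0).map(_._1)

// localNonZeroBlocks == Seq("shuffle_0_1_0"): the two empty blocks are never
// read from disk, and _totalBlocks is reduced by
// localBlockIds.size - localNonZeroBlocks.size = 2.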
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index cd85fa1e9d..c1cff25552 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private var bs: OutputStream = null
private var objOut: SerializationStream = null
private var lastValidPosition = 0L
+ private var initialized = false
override def open(): DiskBlockObjectWriter = {
val fos = new FileOutputStream(f, true)
channel = fos.getChannel()
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
this
}
override def close() {
- objOut.close()
- bs.close()
- channel = null
- bs = null
- objOut = null
+ if (initialized) {
+ objOut.close()
+ bs.close()
+ channel = null
+ bs = null
+ objOut = null
+ }
// Invoke the close callback handler.
super.close()
}
@@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
// Flush the partial writes, and set valid length to be the length of the entire file.
// Return the number of bytes written for this commit.
override def commit(): Long = {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
}
override def revertPartialWrites() {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
}
override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
objOut.writeObject(value)
}
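
With the initialized guards above, callers can drive a writer through its whole lifecycle without caring whether anything was ever written: commit() reports 0 bytes, and revertPartialWrites()/close() touch nothing, if open() never ran. A hedged sketch of that call pattern; BlockWriterLike is an illustrative stand-in for the methods shown in the diff, not a real Spark trait:

// Illustrative interface mirroring the methods in the diff above.
trait BlockWriterLike {
  def write(value: Any): Unit
  def commit(): Long
  def revertPartialWrites(): Unit
  def close(): Unit
}

// Typical call pattern: open() happens lazily inside write(), so a partition
// with no records commits 0 bytes and never creates a file.
def writePartition(writer: BlockWriterLike, records: Iterator[Any]): Long = {
  try {
    records.foreach(writer.write)   // first write() triggers open()
    writer.commit()                 // bytes written since the last commit
  } catch {
    case e: Exception =>
      writer.revertPartialWrites()  // truncate back to the last valid position
      throw e
  } finally {
    writer.close()                  // no-op unless the stream was initialized
  }
}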
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
index 49eabfb0d2..44638e0c2d 100644
--- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
}
new ShuffleWriterGroup(mapId, writers)
}
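
Since ShuffleBlockManager no longer opens each bucket's writer up front, a bucket that receives no records goes through close() with initialized still false and leaves nothing on disk. Illustrative outcome for one map task (made-up bucket sizes, not a real Spark API):

// Committed bytes per reduce bucket after one hypothetical map task.
val committedSizes = Map(
  0 -> 0L,   // no records for bucket 0: write() never ran, no file created
  1 -> 256L, // bucket 1 got records: file created on first write
  2 -> 0L
)
val nonEmptyBuckets = committedSizes.filter(_._2 > 0).keys.toList.sorted
// nonEmptyBuckets == List(1)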
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index b967016cf7..33b02fff80 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -367,6 +367,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(nonEmptyBlocks.size <= 4)
}
+ test("zero sized blocks without kryo") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
+
+ // NOTE: Keep the default Java serializer here (no Kryo); this is the path
+ //       that should now report zero-sized blocks without writing files.
+ val c = new ShuffledRDD(b, new HashPartitioner(10))
+
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
+
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
+ }
+
}
object ShuffleSuite {