author    Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-05-29 13:09:58 -0700
committer Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-05-29 13:18:54 -0700
commit    618c8cae1ee5dede98824823e00f7863571c0e57 (patch)
tree      23e11903b4fc7615d3e0550b7a1d273452352810
parent    6ed71390d9a9af4f48c6c1aa0e86feb8c8ad3272 (diff)
Skip fetching zero-sized blocks in NIO.
Also unify splitLocalRemoteBlocks for netty/nio and add a test case.
-rw-r--r--  core/src/main/scala/spark/storage/BlockFetcherIterator.scala  61
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala                  27
2 files changed, 39 insertions(+), 49 deletions(-)
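
For context, here is a minimal standalone sketch (Scala) of the grouping logic this commit consolidates into the shared splitLocalRemoteBlocks: zero-sized blocks are dropped before they can be bundled into a FetchRequest, negative sizes are rejected, and the rest are batched into requests of at least maxBytesInFlight / 5 bytes. FetchRequest is simplified here and groupBlocks is a hypothetical helper for illustration; the real method also splits blocks into local and remote by address.

import scala.collection.mutable.ArrayBuffer

case class FetchRequest(blocks: Seq[(String, Long)])

def groupBlocks(blockInfos: Seq[(String, Long)],
    maxBytesInFlight: Long): ArrayBuffer[FetchRequest] = {
  // Keep each request at least maxBytesInFlight / 5 bytes so up to five
  // fetches can run in parallel without one request using the whole budget.
  val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
  val requests = new ArrayBuffer[FetchRequest]
  var curBlocks = new ArrayBuffer[(String, Long)]
  var curRequestSize = 0L
  for ((blockId, size) <- blockInfos) {
    if (size > 0) {
      curBlocks += ((blockId, size))
      curRequestSize += size
    } else if (size < 0) {
      throw new IllegalArgumentException("Negative block size " + size + " for " + blockId)
    }
    // size == 0 falls through: the empty block is skipped entirely.
    if (curRequestSize >= minRequestSize) {
      requests += FetchRequest(curBlocks)
      curBlocks = new ArrayBuffer[(String, Long)]
      curRequestSize = 0L
    }
  }
  if (curBlocks.nonEmpty) {
    requests += FetchRequest(curBlocks) // final partial request
  }
  requests
}

Before this change only the netty-specific override in NettyBlockFetcherIterator skipped empty blocks; hoisting the logic into the base iterator gives the NIO path the same behavior and removes the duplicated loop shown in the second hunk below.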
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index 95308c7282..1d69d658f7 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -124,6 +124,7 @@ object BlockFetcherIterator {
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
+ val originalTotalBlocks = _totalBlocks
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
@@ -140,8 +141,15 @@ object BlockFetcherIterator {
var curBlocks = new ArrayBuffer[(String, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
- curBlocks += ((blockId, size))
- curRequestSize += size
+ // Skip empty blocks
+ if (size > 0) {
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ } else if (size == 0) {
+ _totalBlocks -= 1
+ } else {
+ throw new BlockException(blockId, "Negative block size " + size)
+ }
if (curRequestSize >= minRequestSize) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
@@ -155,6 +163,8 @@ object BlockFetcherIterator {
}
}
}
+ logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
+ originalTotalBlocks + " blocks")
remoteRequests
}
@@ -278,53 +288,6 @@ object BlockFetcherIterator {
logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
}
- override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val originalTotalBlocks = _totalBlocks;
- val remoteRequests = new ArrayBuffer[FetchRequest]
- for ((address, blockInfos) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
- } else {
- remoteBlockIds ++= blockInfos.map(_._1)
- // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- if (size > 0) {
- curBlocks += ((blockId, size))
- curRequestSize += size
- } else if (size == 0) {
- //here we changes the totalBlocks
- _totalBlocks -= 1
- } else {
- throw new BlockException(blockId, "Negative block size " + size)
- }
- if (curRequestSize >= minRequestSize) {
- // Add this FetchRequest
- remoteRequests += new FetchRequest(address, curBlocks)
- curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
- }
- }
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
- }
- }
- logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
- originalTotalBlocks + " blocks")
- remoteRequests
- }
-
private var copiers: List[_ <: Thread] = null
override def initialize() {
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index fdee7ca384..4e50ae2ca9 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -317,6 +317,33 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName)
assert(c.count === 10)
}
+
+ test("zero sized blocks") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
+
+ // NOTE: The default Java serializer doesn't create zero-sized blocks.
+ // So, use Kryo
+ val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName)
+
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
+
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
+ }
+
}
object ShuffleSuite {
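
A note on the test's assertion: four distinct integer keys hashed into ten partitions can occupy at most four of them, so at least six of the NUM_BLOCKS shuffle outputs are zero-sized, and only Kryo (not the default Java serializer) writes those as truly empty blocks. A quick hypothetical check of the partition assignment, mirroring HashPartitioner's non-negative modulo in simplified form:

val numPartitions = 10
val keys = 1 to 4
// For Int keys, hashCode is the value itself; normalize the modulo so
// negative hash codes would still map to a valid partition.
val occupied = keys.map(k => ((k.hashCode % numPartitions) + numPartitions) % numPartitions).toSet
assert(occupied.size <= 4)  // hence at least 6 of the 10 blocks are empty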