4 files changed, 54 insertions(+), 3 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 5ddda4d695..f8584b90ca 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -68,7 +68,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           // Otherwise, cache the values and keep track of any updates in block statuses
           val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
           val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks)
-          context.taskMetrics.updatedBlocks = Some(updatedBlocks)
+          val metrics = context.taskMetrics
+          val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
+          metrics.updatedBlocks = Some(lastUpdatedBlocks ++ updatedBlocks.toSeq)
           new InterruptibleIterator(context, cachedValues)
 
         } finally {
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 67f72a94f0..76097f1c51 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -70,8 +70,11 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar
   }
 
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
-    // Remove all partitions that are no longer cached
-    _rddInfoMap.retain { case (_, info) => info.numCachedPartitions > 0 }
+    // Remove all partitions that are no longer cached in the current completed stage
+    val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet
+    _rddInfoMap.retain { case (id, info) =>
+      !completedRddIds.contains(id) || info.numCachedPartitions > 0
+    }
   }
 
   override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized {
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 9c5f394d38..90dcadcffd 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -32,6 +32,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
   var split: Partition = _
   /** An RDD which returns the values [1, 2, 3, 4]. */
   var rdd: RDD[Int] = _
+  var rdd2: RDD[Int] = _
+  var rdd3: RDD[Int] = _
 
   before {
     sc = new SparkContext("local", "test")
@@ -43,6 +45,16 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       override val getDependencies = List[Dependency[_]]()
       override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
     }
+    rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) {
+      override def getPartitions: Array[Partition] = firstParent[Int].partitions
+      override def compute(split: Partition, context: TaskContext) =
+        firstParent[Int].iterator(split, context)
+    }.cache()
+    rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) {
+      override def getPartitions: Array[Partition] = firstParent[Int].partitions
+      override def compute(split: Partition, context: TaskContext) =
+        firstParent[Int].iterator(split, context)
+    }.cache()
   }
 
   after {
@@ -87,4 +99,11 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       assert(value.toList === List(1, 2, 3, 4))
     }
   }
+
+  test("verify task metrics updated correctly") {
+    cacheManager = sc.env.cacheManager
+    val context = new TaskContext(0, 0, 0)
+    cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
+    assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
index b860177705..a537c72ce7 100644
--- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
@@ -34,6 +34,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
   private val memOnly = StorageLevel.MEMORY_ONLY
   private val none = StorageLevel.NONE
   private val taskInfo = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false)
+  private val taskInfo1 = new TaskInfo(1, 1, 1, 1, "big", "cat", TaskLocality.ANY, false)
   private def rddInfo0 = new RDDInfo(0, "freedom", 100, memOnly)
   private def rddInfo1 = new RDDInfo(1, "hostage", 200, memOnly)
   private def rddInfo2 = new RDDInfo(2, "sanity", 300, memAndDisk)
@@ -162,4 +163,30 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
     assert(storageListener._rddInfoMap(2).numCachedPartitions === 0)
   }
 
+  test("verify StorageTab contains all cached rdds") {
+
+    val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly)
+    val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly)
+    val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), "details")
+    val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), "details")
+    val taskMetrics0 = new TaskMetrics
+    val taskMetrics1 = new TaskMetrics
+    val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L, 0L))
+    val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L, 0L))
+    taskMetrics0.updatedBlocks = Some(Seq(block0))
+    taskMetrics1.updatedBlocks = Some(Seq(block1))
+    bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L))
+    bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
+    assert(storageListener.rddInfoList.size === 0)
+    bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0))
+    assert(storageListener.rddInfoList.size === 1)
+    bus.postToAll(SparkListenerStageSubmitted(stageInfo1))
+    assert(storageListener.rddInfoList.size === 1)
+    bus.postToAll(SparkListenerStageCompleted(stageInfo0))
+    assert(storageListener.rddInfoList.size === 1)
+    bus.postToAll(SparkListenerTaskEnd(1, 0, "small", Success, taskInfo1, taskMetrics1))
+    assert(storageListener.rddInfoList.size === 2)
+    bus.postToAll(SparkListenerStageCompleted(stageInfo1))
+    assert(storageListener.rddInfoList.size === 2)
+  }
 }
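
Background, not part of the patch itself: the CacheManager change matters when cached RDDs are chained. A task calling getOrCompute on rdd3 recurses into rdd2's getOrCompute, and each level records the block it just cached. The old assignment, context.taskMetrics.updatedBlocks = Some(updatedBlocks), let the outer call overwrite the inner call's entry, so only one of the two blocks ever reached the listeners; that is why the new CacheManagerSuite test expects a size of 2. Below is a minimal sketch of the accumulate-rather-than-overwrite pattern, using simplified stand-in types (BlockId, BlockStatus, TaskMetrics here are illustrations, not Spark's real classes):

import scala.collection.mutable.ArrayBuffer

// Simplified stand-ins; the real Spark types carry many more fields.
case class BlockId(name: String)
case class BlockStatus(memSize: Long)
class TaskMetrics {
  var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None
}

object UpdatedBlocksSketch {
  // The fixed pattern: append this call's updates to whatever an inner
  // (recursive) getOrCompute call has already recorded.
  def recordUpdates(metrics: TaskMetrics, updates: ArrayBuffer[(BlockId, BlockStatus)]): Unit = {
    val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
    metrics.updatedBlocks = Some(lastUpdatedBlocks ++ updates.toSeq)
  }

  def main(args: Array[String]): Unit = {
    val metrics = new TaskMetrics
    // Inner call caches the parent RDD's block...
    recordUpdates(metrics, ArrayBuffer(BlockId("rdd_2_0") -> BlockStatus(100L)))
    // ...outer call caches the child RDD's block; both must survive.
    recordUpdates(metrics, ArrayBuffer(BlockId("rdd_3_0") -> BlockStatus(200L)))
    assert(metrics.updatedBlocks.get.size == 2) // the overwriting version would report 1
  }
}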
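Likewise, the StorageTab change narrows the cleanup done on stage completion: the old retain dropped every RDD with zero cached partitions, including RDDs that belonged to other, still-running stages and simply had not reported their cached blocks yet. A sketch of the scoped retain, with a plain mutable map standing in for the listener's _rddInfoMap (this RDDInfo is a simplified stand-in as well):

import scala.collection.mutable

case class RDDInfo(id: Int, numCachedPartitions: Int)

object ScopedRetainSketch {
  def onStageCompleted(rddInfoMap: mutable.Map[Int, RDDInfo], completedRddIds: Set[Int]): Unit = {
    // Only prune uncached RDDs that participated in the completed stage;
    // leave entries owned by other (possibly running) stages untouched.
    rddInfoMap.retain { case (id, info) =>
      !completedRddIds.contains(id) || info.numCachedPartitions > 0
    }
  }

  def main(args: Array[String]): Unit = {
    val map = mutable.Map(
      0 -> RDDInfo(0, numCachedPartitions = 1), // cached, in the completed stage: kept
      1 -> RDDInfo(1, numCachedPartitions = 0)) // uncached, but owned by a running stage
    onStageCompleted(map, completedRddIds = Set(0))
    assert(map.keySet == Set(0, 1)) // the old global retain would have dropped RDD 1
  }
}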