 core/src/main/scala/org/apache/spark/CacheManager.scala              |  4 +++-
 core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala     |  7 +++++--
 core/src/test/scala/org/apache/spark/CacheManagerSuite.scala         | 19 +++++++++++++++++++
 core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala | 27 +++++++++++++++++++++++++++
 4 files changed, 54 insertions(+), 3 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 5ddda4d695..f8584b90ca 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -68,7 +68,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// Otherwise, cache the values and keep track of any updates in block statuses
val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks)
- context.taskMetrics.updatedBlocks = Some(updatedBlocks)
+ val metrics = context.taskMetrics
+ val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
+ metrics.updatedBlocks = Some(lastUpdatedBlocks ++ updatedBlocks.toSeq)
new InterruptibleIterator(context, cachedValues)
} finally {
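
Note on the CacheManager change above: the removed line assigned updatedBlocks wholesale, so a task that materialized more than one cached RDD (each pass through getOrCompute writes this field) kept only the last RDD's block updates. The replacement appends to whatever the task has already recorded. A minimal sketch of the append-instead-of-overwrite pattern, in plain Scala with a stand-in Metrics class and string block ids rather than Spark's TaskMetrics:

    // Stand-in for TaskMetrics: the Option may already hold entries from earlier in the task.
    case class Metrics(var updatedBlocks: Option[Seq[String]] = None)

    def record(m: Metrics, blocks: Seq[String]): Unit = {
      val previous = m.updatedBlocks.getOrElse(Seq.empty)  // keep earlier updates
      m.updatedBlocks = Some(previous ++ blocks)           // append rather than overwrite
    }

    val m = Metrics()
    record(m, Seq("rdd_2_0"))  // first cached RDD computed by the task
    record(m, Seq("rdd_3_0"))  // second cached RDD in the same task
    assert(m.updatedBlocks.get == Seq("rdd_2_0", "rdd_3_0"))
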
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 67f72a94f0..76097f1c51 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -70,8 +70,11 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
- // Remove all partitions that are no longer cached
- _rddInfoMap.retain { case (_, info) => info.numCachedPartitions > 0 }
+ // Remove all partitions that are no longer cached in the completed stage
+ val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet
+ _rddInfoMap.retain { case (id, info) =>
+ !completedRddIds.contains(id) || info.numCachedPartitions > 0
+ }
}
override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized {
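
Note on the StorageTab change: the old onStageCompleted evicted every RDDInfo with zero cached partitions, including RDDs of other stages that were submitted but whose tasks had not yet reported any blocks. Scoping the eviction to the RDDs of the stage that just completed keeps those entries alive. A small sketch of the retention rule, using a plain id -> cached-partition-count map as a stand-in for _rddInfoMap:

    import scala.collection.mutable

    val completedRddIds = Set(0)          // ids taken from stageCompleted.stageInfo.rddInfos
    val rddInfoMap = mutable.Map(
      0 -> 0,  // in the completed stage and nothing cached -> dropped
      1 -> 0)  // belongs to a still-running stage          -> kept despite 0 cached partitions
    rddInfoMap.retain { case (id, cached) => !completedRddIds.contains(id) || cached > 0 }
    assert(rddInfoMap.keySet == Set(1))
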
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 9c5f394d38..90dcadcffd 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -32,6 +32,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
var split: Partition = _
/** An RDD which returns the values [1, 2, 3, 4]. */
var rdd: RDD[Int] = _
+ var rdd2: RDD[Int] = _
+ var rdd3: RDD[Int] = _
before {
sc = new SparkContext("local", "test")
@@ -43,6 +45,16 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
override val getDependencies = List[Dependency[_]]()
override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
}
+ rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) {
+ override def getPartitions: Array[Partition] = firstParent[Int].partitions
+ override def compute(split: Partition, context: TaskContext) =
+ firstParent[Int].iterator(split, context)
+ }.cache()
+ rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) {
+ override def getPartitions: Array[Partition] = firstParent[Int].partitions
+ override def compute(split: Partition, context: TaskContext) =
+ firstParent[Int].iterator(split, context)
+ }.cache()
}
after {
@@ -87,4 +99,11 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
assert(value.toList === List(1, 2, 3, 4))
}
}
+
+ test("verify task metrics updated correctly") {
+ cacheManager = sc.env.cacheManager
+ val context = new TaskContext(0, 0, 0)
+ cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
+ assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
+ }
}
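
Note on the new CacheManagerSuite test: rdd3 is cached and depends on the cached rdd2, which depends on the uncached rdd, so a single getOrCompute on rdd3 caches two blocks. A rough trace of what the assertion counts (inferred from the lineage above, not a verbatim log):

    // getOrCompute(rdd3, split, context, MEMORY_ONLY):
    //   rdd3 miss -> rdd3.compute -> rdd2.iterator
    //     rdd2 miss -> rdd2.compute -> rdd.iterator (uncached, computed directly)
    //     put rdd2's block -> updatedBlocks now has 1 entry
    //   put rdd3's block   -> updatedBlocks now has 2 entries

Under the old overwrite behavior the second write would clobber the first, and the size-2 assertion would see only 1.
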
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
index b860177705..a537c72ce7 100644
--- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
@@ -34,6 +34,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
private val memOnly = StorageLevel.MEMORY_ONLY
private val none = StorageLevel.NONE
private val taskInfo = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false)
+ private val taskInfo1 = new TaskInfo(1, 1, 1, 1, "big", "cat", TaskLocality.ANY, false)
private def rddInfo0 = new RDDInfo(0, "freedom", 100, memOnly)
private def rddInfo1 = new RDDInfo(1, "hostage", 200, memOnly)
private def rddInfo2 = new RDDInfo(2, "sanity", 300, memAndDisk)
@@ -162,4 +163,30 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
assert(storageListener._rddInfoMap(2).numCachedPartitions === 0)
}
+ test("verify StorageTab contains all cached rdds") {
+
+ val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly)
+ val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly)
+ val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), "details")
+ val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), "details")
+ val taskMetrics0 = new TaskMetrics
+ val taskMetrics1 = new TaskMetrics
+ val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L, 0L))
+ val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L, 0L))
+ taskMetrics0.updatedBlocks = Some(Seq(block0))
+ taskMetrics1.updatedBlocks = Some(Seq(block1))
+ bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L))
+ bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
+ assert(storageListener.rddInfoList.size === 0)
+ bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0))
+ assert(storageListener.rddInfoList.size === 1)
+ bus.postToAll(SparkListenerStageSubmitted(stageInfo1))
+ assert(storageListener.rddInfoList.size === 1)
+ bus.postToAll(SparkListenerStageCompleted(stageInfo0))
+ assert(storageListener.rddInfoList.size === 1)
+ bus.postToAll(SparkListenerTaskEnd(1, 0, "small", Success, taskInfo1, taskMetrics1))
+ assert(storageListener.rddInfoList.size === 2)
+ bus.postToAll(SparkListenerStageCompleted(stageInfo1))
+ assert(storageListener.rddInfoList.size === 2)
+ }
}
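
Note on the StorageTabSuite event sequence: the decisive step is SparkListenerStageCompleted(stageInfo0). At that point rdd 1 is known but has no cached partitions yet, so the old retain rule would have evicted it from _rddInfoMap, and the later task-end carrying block1 could no longer surface it. A recap of the expected rddInfoList.size after each event, as asserted above:

    // stage 0 submitted          -> 0  (rdd 0 known, nothing cached)
    // task end with block0       -> 1
    // stage 1 submitted          -> 1  (rdd 1 known, nothing cached yet)
    // stage 0 completed          -> 1  (new rule keeps rdd 1; old rule would drop it)
    // task end with block1       -> 2
    // stage 1 completed          -> 2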