Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/RDD.scala')
 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6a6ad2d75a..e5fdebc65d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -37,7 +37,7 @@ import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
@@ -272,7 +272,7 @@ abstract class RDD[T: ClassTag](
    */
  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
-      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
@@ -315,6 +315,35 @@ abstract class RDD[T: ClassTag](
  }

  /**
+   * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached.
+   */
+  private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
+    val blockId = RDDBlockId(id, partition.index)
+    var readCachedBlock = true
+    // This method is called on executors, so we need to call SparkEnv.get instead of sc.env.
+    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => {
+      readCachedBlock = false
+      computeOrReadCheckpoint(partition, context)
+    }) match {
+      case Left(blockResult) =>
+        if (readCachedBlock) {
+          val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod)
+          existingMetrics.incBytesReadInternal(blockResult.bytes)
+          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
+            override def next(): T = {
+              existingMetrics.incRecordsReadInternal(1)
+              delegate.next()
+            }
+          }
+        } else {
+          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
+        }
+      case Right(iter) =>
+        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
+    }
+  }
+
+  /**
    * Execute a block of code in a scope such that all new RDDs created in this body will
    * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}.
    *
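
For readers tracing the new code path, here is a minimal sketch (not part of the patch) of how it is exercised from user code. It assumes a local-mode SparkContext; the object name is made up for illustration, everything else is the standard Spark user API.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

object CachedIteratorSketch {
  def main(args: Array[String]): Unit = {
    // Local-mode context, for illustration only.
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("getOrCompute-sketch"))
    try {
      // persist() sets storageLevel != StorageLevel.NONE, so RDD.iterator()
      // takes the new getOrCompute(split, context) branch for each partition.
      val doubled = sc.parallelize(1 to 1000).map(_ * 2).persist(StorageLevel.MEMORY_ONLY)

      doubled.count() // first action: the compute closure runs, blocks are stored
      doubled.count() // second action: blocks are served back from the BlockManager
    } finally {
      sc.stop()
    }
  }
}

On the first count(), getOrElseUpdate finds no block for the partition's RDDBlockId, so the closure runs computeOrReadCheckpoint and stores the result; readCachedBlock stays false and no input metrics are recorded. On the second count(), the block comes back as Left(blockResult), so the task's input metrics are updated as the cached iterator is consumed. The Right(iter) case covers blocks that could not be cached, where the computed iterator is returned directly.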