Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/RDD.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala | 33
1 file changed, 31 insertions(+), 2 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6a6ad2d75a..e5fdebc65d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -37,7 +37,7 @@ import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
@@ -272,7 +272,7 @@ abstract class RDD[T: ClassTag](
   */
  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
-      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
@@ -315,6 +315,35 @@ abstract class RDD[T: ClassTag](
  }

  /**
+   * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached.
+   */
+  private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
+    val blockId = RDDBlockId(id, partition.index)
+    var readCachedBlock = true
+    // This method is called on executors, so we need to call SparkEnv.get instead of sc.env.
+    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => {
+      readCachedBlock = false
+      computeOrReadCheckpoint(partition, context)
+    }) match {
+      case Left(blockResult) =>
+        if (readCachedBlock) {
+          val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod)
+          existingMetrics.incBytesReadInternal(blockResult.bytes)
+          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
+            override def next(): T = {
+              existingMetrics.incRecordsReadInternal(1)
+              delegate.next()
+            }
+          }
+        } else {
+          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
+        }
+      case Right(iter) =>
+        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
+    }
+  }
+
+  /**
   * Execute a block of code in a scope such that all new RDDs created in this body will
   * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}.
   *
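
For context, a minimal sketch of how this code path is exercised from user code: persisting an RDD sets a non-NONE storage level, so RDD.iterator() routes through the new getOrCompute (and BlockManager.getOrElseUpdate) instead of the removed CacheManager call. The object name, app name, and local master below are illustrative placeholders only, not part of this change.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

// Hypothetical driver program; only persist()/count() below touch the changed code path.
object CachedRddExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("cached-rdd-example").setMaster("local[2]")
    val sc = new SparkContext(conf)

    // A non-NONE storage level makes RDD.iterator() call getOrCompute on each task.
    val doubled = sc.parallelize(1 to 1000).map(_ * 2).persist(StorageLevel.MEMORY_ONLY)

    // First action: getOrElseUpdate misses, computeOrReadCheckpoint runs, RDD blocks are stored.
    val first = doubled.count()

    // Second action: getOrElseUpdate returns the cached blocks, and the
    // InterruptibleIterator wrapper in the diff updates bytes/records read metrics.
    val second = doubled.count()

    assert(first == second)
    sc.stop()
  }
}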