diff options
author:    Aaron Davidson <aaron@databricks.com>  2013-09-05 15:28:14 -0700
committer: Aaron Davidson <aaron@databricks.com>  2013-09-05 15:34:42 -0700
commit:    1418d18af43229b442d3ed747fdb8088d4fa5b6f (patch)
tree:      dba5fb804374b0a6bc3523e9b1ef2ef20f58e90c /core
parent:    714e7f9e32590c302ad315b7cbee72b2e8b32b9b (diff)
download:  spark-1418d18af43229b442d3ed747fdb8088d4fa5b6f.tar.gz
           spark-1418d18af43229b442d3ed747fdb8088d4fa5b6f.tar.bz2
           spark-1418d18af43229b442d3ed747fdb8088d4fa5b6f.zip
SPARK-821: Don't cache results when action run locally on driver
Caching the results of local actions (e.g., rdd.first()) causes the driver to
store entire partitions in its own memory, which may be highly constrained.
This patch simply makes the CacheManager avoid caching the results of all locally-run computations.
Diffstat (limited to 'core')
4 files changed, 5 insertions, 4 deletions
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index e299a106ee..a6f701b880 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -69,8 +69,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         val elements = new ArrayBuffer[Any]
         logInfo("Computing partition " + split)
         elements ++= rdd.computeOrReadCheckpoint(split, context)
-        // Try to put this block in the blockManager
-        blockManager.put(key, elements, storageLevel, true)
+        // Persist the result, so long as the task is not running locally
+        if (!context.runningLocally) blockManager.put(key, elements, storageLevel, true)
         return elements.iterator.asInstanceOf[Iterator[T]]
       } finally {
         loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b2dd668330..c2c358c7ad 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ class TaskContext(
   val stageId: Int,
   val splitId: Int,
   val attemptId: Long,
+  val runningLocally: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 92add5b073..b739118e2f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -478,7 +478,7 @@ class DAGScheduler(
     SparkEnv.set(env)
     val rdd = job.finalStage.rdd
     val split = rdd.partitions(job.partitions(0))
-    val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+    val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
     try {
       val result = job.func(taskContext, rdd.iterator(split, taskContext))
       job.listener.taskSucceeded(0, result)
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 8a869c9005..591c1d498d 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }