author     Patrick Wendell <pwendell@gmail.com>   2013-09-05 23:54:09 -0700
committer  Patrick Wendell <pwendell@gmail.com>   2013-09-05 23:54:09 -0700
commit     ddcb9d310abd258ba49ee764ff076dee9178f975 (patch)
tree       1c8932ae496dcda746ecb18feae778143e0566f1
parent     699c331f2f1ea57adf43f9bf5c56f8400b45acb9 (diff)
parent     3a04e76c89ac44a1345417597971e6e37944d371 (diff)
Merge pull request #895 from ilikerps/821
SPARK-821: Don't cache results when action run locally on driver
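
In short: when the scheduler runs an action locally on the driver (the short-job path that actions such as first() can take), the computed partition is now handed straight back to the caller instead of being materialized and written to the BlockManager. A minimal sketch of the new control flow, with simplified, hypothetical signatures (the real ones are in the CacheManager hunk below):

    // Hypothetical, simplified shape of CacheManager.getOrCompute after this
    // patch; `persist` stands in for blockManager.put.
    def getOrCompute[T](compute: () => Iterator[T],
                        runningLocally: Boolean,
                        persist: Seq[Any] => Unit): Iterator[T] = {
      val computedValues = compute()
      if (runningLocally) {
        return computedValues        // driver-local run: skip the cache entirely
      }
      val elements = computedValues.toSeq
      persist(elements)              // normal task: materialize, then cache
      elements.iterator
    }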
-rw-r--r--  core/src/main/scala/org/apache/spark/CacheManager.scala             |  8
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala              |  1
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala   |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala     |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala        | 91
-rw-r--r--  core/src/test/scala/org/apache/spark/JavaAPISuite.java              |  2
7 files changed, 102 insertions(+), 7 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index e299a106ee..68b99ca125 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -66,10 +66,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         }
         try {
           // If we got here, we have to load the split
-          val elements = new ArrayBuffer[Any]
           logInfo("Computing partition " + split)
-          elements ++= rdd.computeOrReadCheckpoint(split, context)
-          // Try to put this block in the blockManager
+          val computedValues = rdd.computeOrReadCheckpoint(split, context)
+          // Persist the result, so long as the task is not running locally
+          if (context.runningLocally) { return computedValues }
+          val elements = new ArrayBuffer[Any]
+          elements ++= computedValues
           blockManager.put(key, elements, storageLevel, true)
           return elements.iterator.asInstanceOf[Iterator[T]]
         } finally {
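
For context, a hedged sketch of the kind of program that can reach the new branch (assuming an action such as first(), which the scheduler may choose to run locally on the driver; the caller never sets the flag itself):

    import org.apache.spark.SparkContext

    // Illustrative only: if the scheduler runs first() locally on the driver,
    // the guard above returns computedValues directly and nothing is put in
    // the BlockManager for that partition.
    val sc = new SparkContext("local", "example")
    val data = sc.parallelize(1 to 100).cache()
    val head = data.first()   // may take the runningLocally == true path
    sc.stop()

Returning computedValues directly also skips copying the partition into an ArrayBuffer when the values are only consumed once on the driver.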
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b2dd668330..c2c358c7ad 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ class TaskContext(
   val stageId: Int,
   val splitId: Int,
   val attemptId: Long,
+  val runningLocally: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
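
Since runningLocally defaults to false, Scala call sites that pass only the first three arguments positionally keep compiling unchanged; only the driver-local path opts in. An illustrative sketch of the two construction styles (argument values are placeholders):

    // Regular task: the default applies (runningLocally = false).
    val executorContext = new TaskContext(0, 0, 0L)
    // Driver-local run: opt in by name, as DAGScheduler does below.
    val localContext = new TaskContext(0, 0, 0L, runningLocally = true)

Java callers get no default arguments, which is why the JavaAPISuite constructor call at the end of this diff must now pass the flag explicitly.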
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index cfcabca0b7..3e3f04f087 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -478,7 +478,8 @@ class DAGScheduler(
     SparkEnv.set(env)
     val rdd = job.finalStage.rdd
     val split = rdd.partitions(job.partitions(0))
-    val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+    val taskContext =
+      new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
     try {
       val result = job.func(taskContext, rdd.iterator(split, taskContext))
       job.listener.taskSucceeded(0, result)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 2b007cbe82..ca44ebb189 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -93,7 +93,7 @@ private[spark] class ResultTask[T, U](
   }
 
   override def run(attemptId: Long): U = {
-    val context = new TaskContext(stageId, partition, attemptId)
+    val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
     metrics = Some(context.taskMetrics)
     try {
       func(context, rdd.iterator(split, context))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 764775fede..d23df0dd2b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -132,7 +132,7 @@ private[spark] class ShuffleMapTask(
   override def run(attemptId: Long): MapStatus = {
     val numOutputSplits = dep.partitioner.numPartitions
-    val taskContext = new TaskContext(stageId, partition, attemptId)
+    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
     metrics = Some(taskContext.taskMetrics)
 
     val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
new file mode 100644
index 0000000000..214f8244d5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.mock.EasyMockSugar
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockManager, StorageLevel}
+
+// TODO: Test the CacheManager's thread-safety aspects
+class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
+  var sc: SparkContext = _
+  var blockManager: BlockManager = _
+  var cacheManager: CacheManager = _
+  var split: Partition = _
+  /** An RDD which returns the values [1, 2, 3, 4]. */
+  var rdd: RDD[Int] = _
+
+  before {
+    sc = new SparkContext("local", "test")
+    blockManager = mock[BlockManager]
+    cacheManager = new CacheManager(blockManager)
+    split = new Partition { override def index: Int = 0 }
+    rdd = new RDD(sc, Nil) {
+      override def getPartitions: Array[Partition] = Array(split)
+      override val getDependencies = List[Dependency[_]]()
+      override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
+    }
+  }
+
+  after {
+    sc.stop()
+  }
+
+  test("get uncached rdd") {
+    expecting {
+      blockManager.get("rdd_0_0").andReturn(None)
+      blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true).
+        andReturn(0)
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(1, 2, 3, 4))
+    }
+  }
+
+  test("get cached rdd") {
+    expecting {
+      blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(5, 6, 7))
+    }
+  }
+
+  test("get uncached local rdd") {
+    expecting {
+      // Local computation should not persist the resulting value, so don't expect a put().
+      blockManager.get("rdd_0_0").andReturn(None)
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(1, 2, 3, 4))
+    }
+  }
+}
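
The suite leans on ScalaTest's EasyMockSugar: calls recorded inside expecting { ... } become the mock's expectations, and whenExecuting(mock) { ... } replays the body and verifies that exactly those calls occurred, which is how the third test asserts the absence of a put(). A minimal, self-contained sketch of the same pattern, using a hypothetical Greeter trait that is not part of this patch:

    import org.scalatest.FunSuite
    import org.scalatest.mock.EasyMockSugar

    class GreeterSuite extends FunSuite with EasyMockSugar {
      trait Greeter { def greet(name: String): String }

      test("greet is called exactly once") {
        val greeter = mock[Greeter]
        expecting {
          greeter.greet("spark").andReturn("hello, spark")
        }
        whenExecuting(greeter) {
          // A call not recorded above, or a missing expected call, fails verification.
          assert(greeter.greet("spark") === "hello, spark")
        }
      }
    }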
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 8a869c9005..591c1d498d 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
 
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }