aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorAaron Davidson <aaron@databricks.com>2013-09-05 18:01:43 -0700
committerAaron Davidson <aaron@databricks.com>2013-09-05 18:06:30 -0700
commit4f2236a1c5fdc6c3a5711cef73407f926bf1a779 (patch)
tree4a6ca7f895137ba0d01a76d8e939b8e51fd864cd /core
parent1418d18af43229b442d3ed747fdb8088d4fa5b6f (diff)
downloadspark-4f2236a1c5fdc6c3a5711cef73407f926bf1a779.tar.gz
spark-4f2236a1c5fdc6c3a5711cef73407f926bf1a779.tar.bz2
spark-4f2236a1c5fdc6c3a5711cef73407f926bf1a779.zip
Add unit test and address comments
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/CacheManager.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/CacheManagerSuite.scala89
5 files changed, 98 insertions, 6 deletions
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index a6f701b880..85a2b2b331 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -66,11 +66,13 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
}
try {
// If we got here, we have to load the split
- val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
- elements ++= rdd.computeOrReadCheckpoint(split, context)
+ val computedValues = rdd.computeOrReadCheckpoint(split, context)
// Persist the result, so long as the task is not running locally
- if (!context.runningLocally) blockManager.put(key, elements, storageLevel, true)
+ if (context.runningLocally) return computedValues
+ val elements = new ArrayBuffer[Any]
+ elements ++= computedValues
+ blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b739118e2f..ba329e1a57 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -478,7 +478,8 @@ class DAGScheduler(
SparkEnv.set(env)
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
- val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
+ val taskContext =
+ new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 2b007cbe82..ca44ebb189 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -93,7 +93,7 @@ private[spark] class ResultTask[T, U](
}
override def run(attemptId: Long): U = {
- val context = new TaskContext(stageId, partition, attemptId)
+ val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 764775fede..d23df0dd2b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -132,7 +132,7 @@ private[spark] class ShuffleMapTask(
override def run(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
- val taskContext = new TaskContext(stageId, partition, attemptId)
+ val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
metrics = Some(taskContext.taskMetrics)
val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
new file mode 100644
index 0000000000..a85d666ace
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.Array
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.mock.EasyMockSugar
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.junit.Assert
+
+// TODO: Test the CacheManager's thread-safety aspects
+class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
+ val sc = new SparkContext("local", "test")
+
+ val split = new Partition { override def index: Int = 0 }
+
+ // An RDD which returns the values [1, 2, 3, 4].
+ val rdd = new RDD[Int](sc, Nil) {
+ override def getPartitions: Array[Partition] = Array(split)
+ override val getDependencies = List[Dependency[_]]()
+ override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
+ }
+
+ var blockManager: BlockManager = _
+ var cacheManager: CacheManager = _
+
+ before {
+ blockManager = mock[BlockManager]
+ cacheManager = new CacheManager(blockManager)
+ }
+
+ test("get uncached rdd") {
+ expecting {
+ blockManager.get("rdd_0_0").andReturn(None)
+ blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true).
+ andReturn(0)
+ }
+
+ whenExecuting(blockManager) {
+ val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+ val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+ Assert.assertEquals(value.toList, List(1, 2, 3, 4))
+ }
+ }
+
+ test("get cached rdd") {
+ expecting {
+ blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
+ }
+
+ whenExecuting(blockManager) {
+ val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+ val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+ Assert.assertEquals(value.toList, List(5, 6, 7))
+ }
+ }
+
+ test("get uncached local rdd") {
+ expecting {
+ // Local computation should not persist the resulting value, so don't expect a put().
+ blockManager.get("rdd_0_0").andReturn(None)
+ }
+
+ whenExecuting(blockManager) {
+ val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+ val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+ Assert.assertEquals(value.toList, List(1, 2, 3, 4))
+ }
+ }
+}