aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bagel/src/test/scala/bagel/BagelSuite.scala2
-rw-r--r--core/src/main/scala/spark/MapOutputTracker.scala12
-rw-r--r--core/src/main/scala/spark/RDDCache.scala19
-rw-r--r--core/src/main/scala/spark/SparkContext.scala3
-rw-r--r--core/src/main/scala/spark/repl/SparkInterpreterLoop.scala5
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala11
-rw-r--r--core/src/test/scala/spark/repl/ReplSuite.scala2
7 files changed, 49 insertions, 5 deletions
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index 1b47fc9ed5..9e64d3f136 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -28,6 +28,7 @@ class BagelSuite extends FunSuite with Assertions {
})
for (vert <- result.collect)
assert(vert.age === numSupersteps)
+ sc.stop()
}
test("halting by message silence") {
@@ -49,5 +50,6 @@ class BagelSuite extends FunSuite with Assertions {
})
for (vert <- result.collect)
assert(vert.age === numSupersteps)
+ sc.stop()
}
}
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 07fd605cca..4334034ecb 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -9,6 +9,7 @@ import scala.collection.mutable.HashSet
sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
+case object StopMapOutputTracker extends MapOutputTrackerMessage
class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]])
extends DaemonActor with Logging {
@@ -23,6 +24,9 @@ extends DaemonActor with Logging {
case GetMapOutputLocations(shuffleId: Int) =>
logInfo("Asked to get map output locations for shuffle " + shuffleId)
reply(serverUris.get(shuffleId))
+ case StopMapOutputTracker =>
+ reply('OK)
+ exit()
}
}
}
@@ -95,4 +99,10 @@ object MapOutputTracker extends Logging {
def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
"%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
}
-} \ No newline at end of file
+
+ def stop() {
+ trackerActor !? StopMapOutputTracker
+ serverUris.clear()
+ trackerActor = null
+ }
+}
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala
index d6c63e61ec..c5557159a6 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/RDDCache.scala
@@ -12,6 +12,7 @@ case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends Ca
case class MemoryCacheLost(host: String) extends CacheMessage
case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
case object GetCacheLocations extends CacheMessage
+case object StopCacheTracker extends CacheMessage
class RDDCacheTracker extends DaemonActor with Logging {
val locs = new HashMap[Int, Array[List[String]]]
@@ -50,6 +51,10 @@ class RDDCacheTracker extends DaemonActor with Logging {
locsCopy(rddId) = array.clone()
}
reply(locsCopy)
+
+ case StopCacheTracker =>
+ reply('OK)
+ exit()
}
}
}
@@ -57,15 +62,15 @@ class RDDCacheTracker extends DaemonActor with Logging {
private object RDDCache extends Logging {
// Stores map results for various splits locally
- val cache = Cache.newKeySpace()
+ var cache: KeySpace = null
- // Remembers which splits are currently being loaded
+ // Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[(Int, Int)]
// Tracker actor on the master, or remote reference to it on workers
var trackerActor: AbstractActor = null
- val registeredRddIds = new HashSet[Int]
+ var registeredRddIds: HashSet[Int] = null
def initialize(isMaster: Boolean) {
if (isMaster) {
@@ -77,6 +82,8 @@ private object RDDCache extends Logging {
val port = System.getProperty("spark.master.port").toInt
trackerActor = RemoteActor.select(Node(host, port), 'RDDCacheTracker)
}
+ registeredRddIds = new HashSet[Int]
+ cache = Cache.newKeySpace()
}
// Registers an RDD (on master only)
@@ -138,4 +145,10 @@ private object RDDCache extends Logging {
return Iterator.fromArray(array)
}
}
+
+ def stop() {
+ trackerActor !? StopCacheTracker
+ registeredRddIds.clear()
+ trackerActor = null
+ }
}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 9357db22c4..dc6964e14b 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -121,6 +121,9 @@ extends Logging {
def stop() {
scheduler.stop()
scheduler = null
+ // TODO: Broadcast.stop(), Cache.stop()?
+ MapOutputTracker.stop()
+ RDDCache.stop()
}
// Wait for the scheduler to be registered
diff --git a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
index d4974009ce..a118abf3ca 100644
--- a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
+++ b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
@@ -260,6 +260,8 @@ extends InterpreterControl {
plushln("Type :help for more information.")
}
+ var sparkContext: SparkContext = null
+
def createSparkContext(): SparkContext = {
val master = this.master match {
case Some(m) => m
@@ -268,7 +270,8 @@ extends InterpreterControl {
if (prop != null) prop else "local"
}
}
- new SparkContext(master, "Spark shell")
+ sparkContext = new SparkContext(master, "Spark shell")
+ sparkContext
}
/** The main read-eval-print loop for the interpreter. It calls
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index a5773614e8..3089360756 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -18,6 +18,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
+ sc.stop()
}
test("groupByKey with duplicates") {
@@ -29,6 +30,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
+ sc.stop()
}
test("groupByKey with negative key hash codes") {
@@ -40,6 +42,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
+ sc.stop()
}
test("groupByKey with many output partitions") {
@@ -51,6 +54,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
+ sc.stop()
}
test("reduceByKey") {
@@ -58,6 +62,7 @@ class ShuffleSuite extends FunSuite {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
+ sc.stop()
}
test("reduceByKey with collectAsMap") {
@@ -67,6 +72,7 @@ class ShuffleSuite extends FunSuite {
assert(sums.size === 2)
assert(sums(1) === 7)
assert(sums(2) === 1)
+ sc.stop()
}
test("reduceByKey with many output partitons") {
@@ -74,6 +80,7 @@ class ShuffleSuite extends FunSuite {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
+ sc.stop()
}
test("join") {
@@ -88,6 +95,7 @@ class ShuffleSuite extends FunSuite {
(2, (1, 'y')),
(2, (1, 'z'))
))
+ sc.stop()
}
test("join all-to-all") {
@@ -104,6 +112,7 @@ class ShuffleSuite extends FunSuite {
(1, (3, 'x')),
(1, (3, 'y'))
))
+ sc.stop()
}
test("join with no matches") {
@@ -112,6 +121,7 @@ class ShuffleSuite extends FunSuite {
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 0)
+ sc.stop()
}
test("join with many output partitions") {
@@ -126,5 +136,6 @@ class ShuffleSuite extends FunSuite {
(2, (1, 'y')),
(2, (1, 'z'))
))
+ sc.stop()
}
}
diff --git a/core/src/test/scala/spark/repl/ReplSuite.scala b/core/src/test/scala/spark/repl/ReplSuite.scala
index 225e766c71..829b1d934e 100644
--- a/core/src/test/scala/spark/repl/ReplSuite.scala
+++ b/core/src/test/scala/spark/repl/ReplSuite.scala
@@ -27,6 +27,8 @@ class ReplSuite extends FunSuite {
val separator = System.getProperty("path.separator")
interp.main(Array("-classpath", paths.mkString(separator)))
spark.repl.Main.interp = null
+ if (interp.sparkContext != null)
+ interp.sparkContext.stop()
return out.toString
}