author    Matei Zaharia <matei@eecs.berkeley.edu>  2011-05-13 12:03:58 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>  2011-05-13 12:03:58 -0700
commit    4db50e26c75263b2edae468b0e8a9283b5c2e6f1 (patch)
tree      19572154b19a6b4aa4309020b8c2c29803868e84 /core
parent    aca8150c52d0cd3f54b47ab90010ab369a820844 (diff)
Fixed unit tests by making them clean up the SparkContext after use and
thus clean up the various singletons (RDDCache, MapOutputTracker, etc). This isn't perfect yet (ideally we shouldn't use singleton objects at all) but we can fix that later.
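
For illustration, the cleanup pattern this commit relies on looks like the following in a test (a minimal sketch, not code from this commit; the suite name and the job it runs are hypothetical):

    import org.scalatest.FunSuite
    import spark.SparkContext

    class ExampleSuite extends FunSuite {
      test("context is cleaned up after use") {
        val sc = new SparkContext("local", "test")
        try {
          val sums = sc.parallelize(Array((1, 1), (1, 2))).reduceByKey(_ + _).collect()
          assert(sums.toSet === Set((1, 3)))
        } finally {
          sc.stop()  // tears down the scheduler and resets the shared singletons
        }
      }
    }

The tests changed below call sc.stop() directly at the end of each test body; the try/finally here is just one way to guarantee the call runs.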
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala          | 12
-rw-r--r--  core/src/main/scala/spark/RDDCache.scala                  | 19
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala              |  3
-rw-r--r--  core/src/main/scala/spark/repl/SparkInterpreterLoop.scala |  5
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala              | 11
-rw-r--r--  core/src/test/scala/spark/repl/ReplSuite.scala            |  2
6 files changed, 47 insertions, 5 deletions
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 07fd605cca..4334034ecb 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -9,6 +9,7 @@ import scala.collection.mutable.HashSet
sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
+case object StopMapOutputTracker extends MapOutputTrackerMessage
class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]])
extends DaemonActor with Logging {
@@ -23,6 +24,9 @@ extends DaemonActor with Logging {
case GetMapOutputLocations(shuffleId: Int) =>
logInfo("Asked to get map output locations for shuffle " + shuffleId)
reply(serverUris.get(shuffleId))
+ case StopMapOutputTracker =>
+ reply('OK)
+ exit()
}
}
}
@@ -95,4 +99,10 @@ object MapOutputTracker extends Logging {
def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
"%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
}
-}
\ No newline at end of file
+
+ def stop() {
+ trackerActor !? StopMapOutputTracker
+ serverUris.clear()
+ trackerActor = null
+ }
+}
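
The StopMapOutputTracker handler above is the usual scala.actors shutdown handshake: the caller sends a synchronous message with !?, and the actor replies before exiting, so the caller only returns once shutdown is acknowledged. In isolation it looks roughly like this (a sketch with hypothetical names, assuming the old scala.actors library this code is built on):

    import scala.actors.Actor

    case object Stop

    class StoppableActor extends Actor {
      def act() {
        react {
          case Stop =>
            reply('OK)  // unblock the sender waiting on !?
            exit()      // then terminate the actor
        }
      }
    }

    val actor = new StoppableActor
    actor.start()
    actor !? Stop  // blocks until the reply arrives, making shutdown synchronous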
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala
index d6c63e61ec..c5557159a6 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/RDDCache.scala
@@ -12,6 +12,7 @@ case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
case class MemoryCacheLost(host: String) extends CacheMessage
case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
case object GetCacheLocations extends CacheMessage
+case object StopCacheTracker extends CacheMessage
class RDDCacheTracker extends DaemonActor with Logging {
val locs = new HashMap[Int, Array[List[String]]]
@@ -50,6 +51,10 @@ class RDDCacheTracker extends DaemonActor with Logging {
locsCopy(rddId) = array.clone()
}
reply(locsCopy)
+
+ case StopCacheTracker =>
+ reply('OK)
+ exit()
}
}
}
@@ -57,15 +62,15 @@ class RDDCacheTracker extends DaemonActor with Logging {
private object RDDCache extends Logging {
// Stores map results for various splits locally
- val cache = Cache.newKeySpace()
+ var cache: KeySpace = null
- // Remembers which splits are currently being loaded
+ // Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[(Int, Int)]
// Tracker actor on the master, or remote reference to it on workers
var trackerActor: AbstractActor = null
- val registeredRddIds = new HashSet[Int]
+ var registeredRddIds: HashSet[Int] = null
def initialize(isMaster: Boolean) {
if (isMaster) {
@@ -77,6 +82,8 @@ private object RDDCache extends Logging {
val port = System.getProperty("spark.master.port").toInt
trackerActor = RemoteActor.select(Node(host, port), 'RDDCacheTracker)
}
+ registeredRddIds = new HashSet[Int]
+ cache = Cache.newKeySpace()
}
// Registers an RDD (on master only)
@@ -138,4 +145,10 @@ private object RDDCache extends Logging {
return Iterator.fromArray(array)
}
}
+
+ def stop() {
+ trackerActor !? StopCacheTracker
+ registeredRddIds.clear()
+ trackerActor = null
+ }
}
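
The val-to-var changes above are what make RDDCache resettable between SparkContexts: state that used to be created once at class-load time is now built in initialize() and torn down in stop(). The shape of the pattern, reduced to its essentials (a sketch with hypothetical names):

    import scala.collection.mutable.HashSet

    object ResettableTracker {
      // Was a val initialized at load time; a var lets stop() discard it
      var registered: HashSet[Int] = null

      def initialize() {
        registered = new HashSet[Int]
      }

      def stop() {
        registered.clear()
        registered = null  // the next initialize() starts from a clean slate
      }
    }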
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 9357db22c4..dc6964e14b 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -121,6 +121,9 @@ extends Logging {
def stop() {
scheduler.stop()
scheduler = null
+ // TODO: Broadcast.stop(), Cache.stop()?
+ MapOutputTracker.stop()
+ RDDCache.stop()
}
// Wait for the scheduler to be registered
diff --git a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
index d4974009ce..a118abf3ca 100644
--- a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
+++ b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
@@ -260,6 +260,8 @@ extends InterpreterControl {
plushln("Type :help for more information.")
}
+ var sparkContext: SparkContext = null
+
def createSparkContext(): SparkContext = {
val master = this.master match {
case Some(m) => m
@@ -268,7 +270,8 @@ extends InterpreterControl {
if (prop != null) prop else "local"
}
}
- new SparkContext(master, "Spark shell")
+ sparkContext = new SparkContext(master, "Spark shell")
+ sparkContext
}
/** The main read-eval-print loop for the interpreter. It calls
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index a5773614e8..3089360756 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -18,6 +18,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
+ sc.stop()
}
test("groupByKey with duplicates") {
@@ -29,6 +30,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
+ sc.stop()
}
test("groupByKey with negative key hash codes") {
@@ -40,6 +42,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
+ sc.stop()
}
test("groupByKey with many output partitions") {
@@ -51,6 +54,7 @@ class ShuffleSuite extends FunSuite {
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
+ sc.stop()
}
test("reduceByKey") {
@@ -58,6 +62,7 @@ class ShuffleSuite extends FunSuite {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
+ sc.stop()
}
test("reduceByKey with collectAsMap") {
@@ -67,6 +72,7 @@ class ShuffleSuite extends FunSuite {
assert(sums.size === 2)
assert(sums(1) === 7)
assert(sums(2) === 1)
+ sc.stop()
}
test("reduceByKey with many output partitons") {
@@ -74,6 +80,7 @@ class ShuffleSuite extends FunSuite {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
+ sc.stop()
}
test("join") {
@@ -88,6 +95,7 @@ class ShuffleSuite extends FunSuite {
(2, (1, 'y')),
(2, (1, 'z'))
))
+ sc.stop()
}
test("join all-to-all") {
@@ -104,6 +112,7 @@ class ShuffleSuite extends FunSuite {
(1, (3, 'x')),
(1, (3, 'y'))
))
+ sc.stop()
}
test("join with no matches") {
@@ -112,6 +121,7 @@ class ShuffleSuite extends FunSuite {
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 0)
+ sc.stop()
}
test("join with many output partitions") {
@@ -126,5 +136,6 @@ class ShuffleSuite extends FunSuite {
(2, (1, 'y')),
(2, (1, 'z'))
))
+ sc.stop()
}
}
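
Each test above stops its context explicitly; a later refactoring could centralize this in a ScalaTest fixture instead. A sketch of that alternative (not part of this commit), assuming the BeforeAndAfterEach trait:

    import org.scalatest.{BeforeAndAfterEach, FunSuite}
    import spark.SparkContext

    class ExampleSuite extends FunSuite with BeforeAndAfterEach {
      var sc: SparkContext = null

      override def afterEach() {
        if (sc != null) {
          sc.stop()  // runs even if an assertion failed mid-test
          sc = null
        }
      }

      test("groupByKey") {
        sc = new SparkContext("local", "test")
        val groups = sc.parallelize(Array((1, 1), (1, 2))).groupByKey().collect()
        assert(groups.size === 1)
      }
    }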
diff --git a/core/src/test/scala/spark/repl/ReplSuite.scala b/core/src/test/scala/spark/repl/ReplSuite.scala
index 225e766c71..829b1d934e 100644
--- a/core/src/test/scala/spark/repl/ReplSuite.scala
+++ b/core/src/test/scala/spark/repl/ReplSuite.scala
@@ -27,6 +27,8 @@ class ReplSuite extends FunSuite {
val separator = System.getProperty("path.separator")
interp.main(Array("-classpath", paths.mkString(separator)))
spark.repl.Main.interp = null
+ if (interp.sparkContext != null)
+ interp.sparkContext.stop()
return out.toString
}