aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRichard Benkovsky <richard.benkovsky@gooddata.com>2012-05-20 10:05:43 +0200
committerRichard Benkovsky <richard.benkovsky@gooddata.com>2012-05-22 11:04:54 +0200
commit3a1bcd4028d84fa5cc7a7cb230f41ae6bb87c352 (patch)
tree3bc8794b2f7cd842fd360051b21872b72cf732e4
parent8f2f736d5311968f7fa0baa93b1bbc8d7aeed4e1 (diff)
downloadspark-3a1bcd4028d84fa5cc7a7cb230f41ae6bb87c352.tar.gz
spark-3a1bcd4028d84fa5cc7a7cb230f41ae6bb87c352.tar.bz2
spark-3a1bcd4028d84fa5cc7a7cb230f41ae6bb87c352.zip
Added tests for CacheTrackerActor
-rw-r--r--core/src/main/scala/spark/CacheTracker.scala4
-rw-r--r--core/src/test/scala/spark/CacheTrackerSuite.scala97
2 files changed, 100 insertions, 1 deletions
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 0719f14a39..4867829c17 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -82,7 +82,8 @@ class CacheTrackerActor extends DaemonActor with Logging {
logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
}
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
-
+ reply('OK)
+
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
// TODO: Drop host from the memory locations list of all RDDs
@@ -225,6 +226,7 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
// Called by the Cache to report that an entry has been dropped from it
def dropEntry(datasetId: Any, partition: Int) {
datasetId match {
+ //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
}
}
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
new file mode 100644
index 0000000000..60290d14ca
--- /dev/null
+++ b/core/src/test/scala/spark/CacheTrackerSuite.scala
@@ -0,0 +1,97 @@
+package spark
+
+import org.scalatest.FunSuite
+import collection.mutable.HashMap
+
+class CacheTrackerSuite extends FunSuite {
+
+ test("CacheTrackerActor slave initialization & cache status") {
+ System.setProperty("spark.master.port", "1345")
+ val initialSize = 2L << 20
+
+ val tracker = new CacheTrackerActor
+ tracker.start()
+
+ tracker !? SlaveCacheStarted("host001", initialSize)
+
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L)))
+
+ tracker !? StopCacheTracker
+ }
+
+ test("RegisterRDD") {
+ System.setProperty("spark.master.port", "1345")
+ val initialSize = 2L << 20
+
+ val tracker = new CacheTrackerActor
+ tracker.start()
+
+ tracker !? SlaveCacheStarted("host001", initialSize)
+
+ tracker !? RegisterRDD(1, 3)
+ tracker !? RegisterRDD(2, 1)
+
+ assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List())))
+
+ tracker !? StopCacheTracker
+ }
+
+ test("AddedToCache") {
+ System.setProperty("spark.master.port", "1345")
+ val initialSize = 2L << 20
+
+ val tracker = new CacheTrackerActor
+ tracker.start()
+
+ tracker !? SlaveCacheStarted("host001", initialSize)
+
+ tracker !? RegisterRDD(1, 2)
+ tracker !? RegisterRDD(2, 1)
+
+ tracker !? AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !? AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
+
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+
+ tracker !? StopCacheTracker
+ }
+
+ test("DroppedFromCache") {
+ System.setProperty("spark.master.port", "1345")
+ val initialSize = 2L << 20
+
+ val tracker = new CacheTrackerActor
+ tracker.start()
+
+ tracker !? SlaveCacheStarted("host001", initialSize)
+
+ tracker !? RegisterRDD(1, 2)
+ tracker !? RegisterRDD(2, 1)
+
+ tracker !? AddedToCache(1, 0, "host001", 2L << 15)
+ tracker !? AddedToCache(1, 1, "host001", 2L << 11)
+ tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+
+ tracker !? DroppedFromCache(1, 1, "host001", 2L << 11)
+
+ assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L)))
+ assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
+
+ tracker !? StopCacheTracker
+ }
+
+ /**
+ * Helper function to get cacheLocations from CacheTracker
+ */
+ def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match {
+ case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map {
+ case (i, arr) => (i -> arr.toList)
+ }
+ }
+}