Diffstat (limited to 'core/src/main/scala/spark/RDDCache.scala')
-rw-r--r--   core/src/main/scala/spark/RDDCache.scala   76
1 file changed, 61 insertions(+), 15 deletions(-)
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala
index 2f2ec9d237..aae2d74900 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/RDDCache.scala
@@ -3,31 +3,57 @@ package spark
 import scala.actors._
 import scala.actors.Actor._
 import scala.actors.remote._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
 
 sealed trait CacheMessage
-case class CacheEntryAdded(rddId: Int, partition: Int, host: String)
-case class CacheEntryRemoved(rddId: Int, partition: Int, host: String)
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class MemoryCacheLost(host: String) extends CacheMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
+case object GetCacheLocations extends CacheMessage
 
 class RDDCacheTracker extends DaemonActor with Logging {
+  val locs = new HashMap[Int, Array[List[String]]]
+  // TODO: Should probably store (String, CacheType) tuples
+
   def act() {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('RDDCacheTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
     loop {
       react {
-        case CacheEntryAdded(rddId, partition, host) =>
-          logInfo("Cache entry added: %s, %s, %s".format(rddId, partition, host))
+        case RegisterRDD(rddId: Int, numPartitions: Int) =>
+          logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+          locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+          reply("")
+
+        case AddedToCache(rddId, partition, host) =>
+          logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) = host :: locs(rddId)(partition)
 
-        case CacheEntryRemoved(rddId, partition, host) =>
-          logInfo("Cache entry removed: %s, %s, %s".format(rddId, partition, host))
+        case DroppedFromCache(rddId, partition, host) =>
+          logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) -= host
+
+        case MemoryCacheLost(host) =>
+          logInfo("Memory cache lost on " + host)
+          // TODO: Drop host from the memory locations list of all RDDs
+
+        case GetCacheLocations =>
+          logInfo("Asked for current cache locations")
+          val locsCopy = new HashMap[Int, Array[List[String]]]
+          for ((rddId, array) <- locs) {
+            locsCopy(rddId) = array.clone()
+          }
+          reply(locsCopy)
       }
     }
  }
 }
 
-import scala.collection.mutable.HashSet
 
 private object RDDCache extends Logging {
   // Stores map results for various splits locally
   val cache = Cache.newKeySpace()
@@ -38,6 +64,8 @@ private object RDDCache extends Logging {
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null
 
+  val registeredRddIds = new HashSet[Int]
+
   def initialize(isMaster: Boolean) {
     if (isMaster) {
       val tracker = new RDDCacheTracker
@@ -50,16 +78,34 @@ private object RDDCache extends Logging {
     }
   }
 
+  // Registers an RDD (on master only)
+  def registerRDD(rddId: Int, numPartitions: Int) {
+    registeredRddIds.synchronized {
+      if (!registeredRddIds.contains(rddId)) {
+        registeredRddIds += rddId
+        trackerActor !? RegisterRDD(rddId, numPartitions)
+      }
+    }
+  }
+
+  // Get a snapshot of the currently known locations
+  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+    (trackerActor !? GetCacheLocations) match {
+      case h: HashMap[Int, Array[List[String]]] => h
+      case _ => throw new SparkException(
+          "Internal error: RDDCache did not reply with a HashMap")
+    }
+  }
+
   // Gets or computes an RDD split
   def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
       : Iterator[T] = {
     val key = (rdd.id, split.index)
-    logInfo("CachedRDD split key is " + key)
-    val cache = RDDCache.cache
-    val loading = RDDCache.loading
+    logInfo("CachedRDD partition key is " + key)
     val cachedVal = cache.get(key)
     if (cachedVal != null) {
       // Split is in cache, so just return its values
+      logInfo("Found partition in cache!")
       return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
     } else {
       // Mark the split as loading (unless someone else marks it first)
@@ -73,13 +119,13 @@ private object RDDCache extends Logging {
         loading.add(key)
       }
     }
-    val host = System.getProperty("spark.hostname", Utils.localHostName)
-    trackerActor ! CacheEntryAdded(rdd.id, split.index, host)
     // If we got here, we have to load the split
+    // Tell the master that we're doing so
+    val host = System.getProperty("spark.hostname", Utils.localHostName)
+    trackerActor ! AddedToCache(rdd.id, split.index, host)
     // TODO: fetch any remote copy of the split that may be available
-    // TODO: also notify the master that we're loading it
     // TODO: also register a listener for when it unloads
-    logInfo("Computing and caching " + split)
+    logInfo("Computing partition " + split)
     val array = rdd.compute(split).toArray(m)
     cache.put(key, array)
     loading.synchronized {
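
For readers skimming the patch, the tracker's new state boils down to one Array[List[String]] per registered RDD: a slot of host lists per partition, grown by AddedToCache, shrunk by DroppedFromCache, and copied wholesale by GetCacheLocations. The following is a minimal, actor-free sketch of that bookkeeping under the same semantics; the object name, method names, and hostnames are illustrative, not part of the patch.

import scala.collection.mutable.HashMap

// Sketch of the tracker's location bookkeeping, minus the actor loop.
object CacheLocationsSketch {
  val locs = new HashMap[Int, Array[List[String]]]

  // One empty host list per partition, as in the RegisterRDD handler
  def registerRDD(rddId: Int, numPartitions: Int): Unit =
    locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)

  def addedToCache(rddId: Int, partition: Int, host: String): Unit =
    locs(rddId)(partition) = host :: locs(rddId)(partition)

  def droppedFromCache(rddId: Int, partition: Int, host: String): Unit =
    locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)

  // Clone each array so callers can't observe later mutations,
  // mirroring the locsCopy loop in the GetCacheLocations handler
  def snapshot(): HashMap[Int, Array[List[String]]] = {
    val copy = new HashMap[Int, Array[List[String]]]
    for ((rddId, array) <- locs) copy(rddId) = array.clone()
    copy
  }

  def main(args: Array[String]): Unit = {
    registerRDD(0, 2)
    addedToCache(0, 0, "host-a")   // hypothetical hostnames
    addedToCache(0, 1, "host-b")
    droppedFromCache(0, 0, "host-a")
    println(snapshot()(0).toList)  // List(List(), List(host-b))
  }
}

Note that the snapshot copies only the arrays, not the immutable lists inside them, which is enough because the lists are never mutated in place.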
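
The last hunk also shows getOrCompute relying on the loading set: the first thread to claim a key computes the split while late arrivals block, then re-check the cache. A hedged sketch of that double-checked guard follows, assuming the winning thread removes the key and notifies waiters once the value is cached; LoadingGuard, claimOrWait, and finished are hypothetical names, not Spark API.

import scala.collection.mutable.HashSet

// Sketch of a "mark the split as loading" guard. The first thread to claim
// a key proceeds to compute it; any other thread blocks on the set's monitor
// until the loader calls finished(), then re-checks the cache.
object LoadingGuard {
  private val loading = new HashSet[(Int, Int)]  // (rddId, splitIndex) keys

  // True if the caller won the race and must compute the value itself;
  // false if another thread finished loading it while we waited.
  def claimOrWait(key: (Int, Int)): Boolean = loading.synchronized {
    if (loading.contains(key)) {
      while (loading.contains(key)) loading.wait()
      false
    } else {
      loading.add(key)
      true
    }
  }

  // Called by the winning thread after cache.put, releasing any waiters.
  def finished(key: (Int, Int)): Unit = loading.synchronized {
    loading.remove(key)
    loading.notifyAll()
  }
}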