From 1df5a65a0158716c5634c55d57578fd00d3f5f1f Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Sun, 6 Mar 2011 12:16:38 -0800
Subject: Pass cache locations correctly to DAGScheduler.

---
 core/src/main/scala/spark/DAGScheduler.scala     | 33 +++++-----
 core/src/main/scala/spark/MapOutputTracker.scala |  2 +-
 core/src/main/scala/spark/RDDCache.scala         | 76 +++++++++++++++++++-----
 core/src/main/scala/spark/SparkContext.scala     |  1 +
 4 files changed, 80 insertions(+), 32 deletions(-)

diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index ee3fda25a8..734cbea822 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -33,20 +33,14 @@ private abstract class DAGScheduler extends Scheduler with Logging {
 
   val shuffleToMapStage = new HashMap[ShuffleDependency[_,_,_], Stage]
 
-  val cacheLocs = new HashMap[RDD[_], Array[List[String]]]
+  var cacheLocs = new HashMap[Int, Array[List[String]]]
 
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
-    cacheLocs.getOrElseUpdate(rdd, Array.fill[List[String]](rdd.splits.size)(Nil))
+    cacheLocs(rdd.id)
   }
-
-  def addCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) = host :: locs(partition)
-  }
-
-  def removeCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) -= host
+
+  def updateCacheLocs() {
+    cacheLocs = RDDCache.getLocationsSnapshot()
   }
 
   def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
@@ -60,6 +54,9 @@ private abstract class DAGScheduler extends Scheduler with Logging {
   }
 
   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
+    // Kind of ugly: need to register RDDs with the cache here since
+    // we can't do it in its constructor because # of splits is unknown
+    RDDCache.registerRDD(rdd.id, rdd.splits.size)
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
@@ -113,10 +110,10 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     missing.toList
   }
 
-  override def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
+  override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
       : Array[U] = {
-    val numOutputParts: Int = rdd.splits.size
-    val finalStage = newStage(rdd, None)
+    val numOutputParts: Int = finalRdd.splits.size
+    val finalStage = newStage(finalRdd, None)
     val results = new Array[U](numOutputParts)
     val finished = new Array[Boolean](numOutputParts)
     var numFinished = 0
@@ -125,6 +122,8 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     val running = new HashSet[Stage]
     val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
 
+    updateCacheLocs()
+
     logInfo("Final stage: " + finalStage)
     logInfo("Parents of final stage: " + finalStage.parents)
     logInfo("Missing parents: " + getMissingParentStages(finalStage))
@@ -145,12 +144,13 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     }
 
     def submitMissingTasks(stage: Stage) {
+      // Get our pending tasks and remember them in our pendingTasks entry
       val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
       var tasks = ArrayBuffer[Task[_]]()
       if (stage == finalStage) {
         for (p <- 0 until numOutputParts if (!finished(p))) {
-          val locs = getPreferredLocs(rdd, p)
-          tasks += new ResultTask(finalStage.id, rdd, func, p, locs)
+          val locs = getPreferredLocs(finalRdd, p)
+          tasks += new ResultTask(finalStage.id, finalRdd, func, p, locs)
         }
       } else {
         for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
@@ -186,6 +186,7 @@ private abstract class DAGScheduler extends Scheduler with Logging {
         if (pending.isEmpty) {
           logInfo(stage + " finished; looking for newly runnable stages")
           running -= stage
+          updateCacheLocs()
           val newlyRunnable = new ArrayBuffer[Stage]
           for (stage <- waiting if getMissingParentStages(stage) == Nil) {
             newlyRunnable += stage
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index ac62c6e411..a253176169 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -11,7 +11,7 @@ class MapOutputTracker extends DaemonActor with Logging {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('MapOutputTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
   }
 }
 
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala
index 2f2ec9d237..aae2d74900 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/RDDCache.scala
@@ -3,31 +3,57 @@ package spark
 import scala.actors._
 import scala.actors.Actor._
 import scala.actors.remote._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
 
 sealed trait CacheMessage
-case class CacheEntryAdded(rddId: Int, partition: Int, host: String)
-case class CacheEntryRemoved(rddId: Int, partition: Int, host: String)
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class MemoryCacheLost(host: String) extends CacheMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
+case object GetCacheLocations extends CacheMessage
 
 class RDDCacheTracker extends DaemonActor with Logging {
+  val locs = new HashMap[Int, Array[List[String]]]
+  // TODO: Should probably store (String, CacheType) tuples
+
   def act() {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('RDDCacheTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
 
     loop {
       react {
-        case CacheEntryAdded(rddId, partition, host) =>
-          logInfo("Cache entry added: %s, %s, %s".format(rddId, partition, host))
+        case RegisterRDD(rddId: Int, numPartitions: Int) =>
+          logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+          locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+          reply("")
+
+        case AddedToCache(rddId, partition, host) =>
+          logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) = host :: locs(rddId)(partition)
 
-        case CacheEntryRemoved(rddId, partition, host) =>
-          logInfo("Cache entry removed: %s, %s, %s".format(rddId, partition, host))
+        case DroppedFromCache(rddId, partition, host) =>
+          logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) -= host
+
+        case MemoryCacheLost(host) =>
+          logInfo("Memory cache lost on " + host)
+          // TODO: Drop host from the memory locations list of all RDDs
+
+        case GetCacheLocations =>
+          logInfo("Asked for current cache locations")
+          val locsCopy = new HashMap[Int, Array[List[String]]]
+          for ((rddId, array) <- locs) {
+            locsCopy(rddId) = array.clone()
+          }
+          reply(locsCopy)
       }
     }
   }
 }
 
-import scala.collection.mutable.HashSet
 private object RDDCache extends Logging {
   // Stores map results for various splits locally
   val cache = Cache.newKeySpace()
@@ -38,6 +64,8 @@ private object RDDCache extends Logging {
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null
 
+  val registeredRddIds = new HashSet[Int]
+
   def initialize(isMaster: Boolean) {
     if (isMaster) {
       val tracker = new RDDCacheTracker
@@ -50,16 +78,34 @@ private object RDDCache extends Logging {
     }
   }
 
+  // Registers an RDD (on master only)
+  def registerRDD(rddId: Int, numPartitions: Int) {
+    registeredRddIds.synchronized {
+      if (!registeredRddIds.contains(rddId)) {
+        registeredRddIds += rddId
+        trackerActor !? RegisterRDD(rddId, numPartitions)
+      }
+    }
+  }
+
+  // Get a snapshot of the currently known locations
+  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+    (trackerActor !? GetCacheLocations) match {
+      case h: HashMap[Int, Array[List[String]]] => h
+      case _ => throw new SparkException(
+          "Internal error: RDDCache did not reply with a HashMap")
+    }
+  }
+
   // Gets or computes an RDD split
   def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
       : Iterator[T] = {
     val key = (rdd.id, split.index)
-    logInfo("CachedRDD split key is " + key)
-    val cache = RDDCache.cache
-    val loading = RDDCache.loading
+    logInfo("CachedRDD partition key is " + key)
     val cachedVal = cache.get(key)
     if (cachedVal != null) {
       // Split is in cache, so just return its values
+      logInfo("Found partition in cache!")
       return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
     } else {
       // Mark the split as loading (unless someone else marks it first)
@@ -73,13 +119,13 @@ private object RDDCache extends Logging {
           loading.add(key)
         }
       }
-      val host = System.getProperty("spark.hostname", Utils.localHostName)
-      trackerActor ! CacheEntryAdded(rdd.id, split.index, host)
       // If we got here, we have to load the split
+      // Tell the master that we're doing so
+      val host = System.getProperty("spark.hostname", Utils.localHostName)
+      trackerActor ! AddedToCache(rdd.id, split.index, host)
       // TODO: fetch any remote copy of the split that may be available
-      // TODO: also notify the master that we're loading it
       // TODO: also register a listener for when it unloads
-      logInfo("Computing and caching " + split)
+      logInfo("Computing partition " + split)
       val array = rdd.compute(split).toArray(m)
       cache.put(key, array)
       loading.synchronized {
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index fda2ee3be7..5cce873c72 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -175,6 +175,7 @@ extends Logging {
 
   private var nextRddId = new AtomicInteger(0)
 
+  // Register a new RDD, returning its RDD ID
   private[spark] def newRddId(): Int = {
     nextRddId.getAndIncrement()
   }
-- 
cgit v1.2.3
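
For readers skimming the patch, the new protocol is: the master-side RDDCacheTracker actor owns the authoritative map of rddId -> per-partition host lists, workers send it AddedToCache/DroppedFromCache messages as they cache or evict partitions, and the DAGScheduler pulls a cloned snapshot (GetCacheLocations) at job start and at stage boundaries instead of mutating shared state through the old addCacheLoc/removeCacheLoc path. The sketch below is a minimal stand-in for that bookkeeping, not code from the commit: the object name CacheTrackerSketch is invented, a synchronized HashMap replaces the scala.actors remoting, and filterNot replaces the patch's deprecated List subtraction (locs(rddId)(partition) -= host).

import scala.collection.mutable.HashMap

// Minimal sketch of the tracker-side bookkeeping introduced above.
// All names and the synchronization scheme are illustrative only.
object CacheTrackerSketch {
  // rddId -> array indexed by partition, each cell listing caching hosts
  private val locs = new HashMap[Int, Array[List[String]]]

  // Mirrors RegisterRDD: idempotent, like the registeredRddIds guard
  def registerRDD(rddId: Int, numPartitions: Int): Unit = locs.synchronized {
    if (!locs.contains(rddId))
      locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
  }

  // Mirrors AddedToCache: prepend the host to the partition's list
  def addedToCache(rddId: Int, partition: Int, host: String): Unit =
    locs.synchronized {
      locs(rddId)(partition) = host :: locs(rddId)(partition)
    }

  // Mirrors DroppedFromCache: remove the host from the partition's list
  def droppedFromCache(rddId: Int, partition: Int, host: String): Unit =
    locs.synchronized {
      locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
    }

  // Mirrors GetCacheLocations: clone each array so the caller reads a
  // stable snapshot; the List cells are immutable, so a shallow array
  // clone is enough (which is why the patch's handler calls array.clone()).
  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] =
    locs.synchronized {
      val copy = new HashMap[Int, Array[List[String]]]
      for ((rddId, array) <- locs)
        copy(rddId) = array.clone()
      copy
    }

  def main(args: Array[String]): Unit = {
    registerRDD(0, 2)
    addedToCache(0, 1, "host-a")
    val snap = getLocationsSnapshot()
    addedToCache(0, 1, "host-b")  // later update, not visible in snap
    println(snap(0)(1))           // prints: List(host-a)
  }
}

The design point worth noting is the pull-based snapshot: rather than letting workers update the scheduler's map directly, the scheduler refreshes its cacheLocs wholesale via updateCacheLocs() at the two points where locations matter (job submission and stage completion), so a single tracker serializes all writes and the scheduler never races concurrent cache updates.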