about | summary | refs | log | tree | commit | diff
path: root/core
diff options
context:
space:
mode:
author: Matei Zaharia <matei@eecs.berkeley.edu> 2011-03-06 12:16:38 -0800
committer: Matei Zaharia <matei@eecs.berkeley.edu> 2011-03-06 12:16:38 -0800
commit: 1df5a65a0158716c5634c55d57578fd00d3f5f1f (patch)
tree: 4af1d3ded72b5c4a4cc4456babaa76e4dc5413c4 /core
parent: e1436f1eaa32e968ae431bee078a54d1c0285535 (diff)
download: spark-1df5a65a0158716c5634c55d57578fd00d3f5f1f.tar.gz
spark-1df5a65a0158716c5634c55d57578fd00d3f5f1f.tar.bz2
spark-1df5a65a0158716c5634c55d57578fd00d3f5f1f.zip
Pass cache locations correctly to DAGScheduler.
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/spark/DAGScheduler.scala 33
-rw-r--r-- core/src/main/scala/spark/MapOutputTracker.scala 2
-rw-r--r-- core/src/main/scala/spark/RDDCache.scala 76
-rw-r--r-- core/src/main/scala/spark/SparkContext.scala 1
4 files changed, 80 insertions, 32 deletions
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index ee3fda25a8..734cbea822 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -33,20 +33,14 @@ private abstract class DAGScheduler extends Scheduler with Logging {
val shuffleToMapStage = new HashMap[ShuffleDependency[_,_,_], Stage]
- val cacheLocs = new HashMap[RDD[_], Array[List[String]]]
+ var cacheLocs = new HashMap[Int, Array[List[String]]]
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
- cacheLocs.getOrElseUpdate(rdd, Array.fill[List[String]](rdd.splits.size)(Nil))
+ cacheLocs(rdd.id)
}
-
- def addCacheLoc(rdd: RDD[_], partition: Int, host: String) {
- val locs = getCacheLocs(rdd)
- locs(partition) = host :: locs(partition)
- }
-
- def removeCacheLoc(rdd: RDD[_], partition: Int, host: String) {
- val locs = getCacheLocs(rdd)
- locs(partition) -= host
+
+ def updateCacheLocs() {
+ cacheLocs = RDDCache.getLocationsSnapshot()
}
def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
@@ -60,6 +54,9 @@ private abstract class DAGScheduler extends Scheduler with Logging {
}
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
+ // Kind of ugly: need to register RDDs with the cache here since
+ // we can't do it in its constructor because # of splits is unknown
+ RDDCache.registerRDD(rdd.id, rdd.splits.size)
val id = newStageId()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
idToStage(id) = stage
@@ -113,10 +110,10 @@ private abstract class DAGScheduler extends Scheduler with Logging {
missing.toList
}
- override def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
+ override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
: Array[U] = {
- val numOutputParts: Int = rdd.splits.size
- val finalStage = newStage(rdd, None)
+ val numOutputParts: Int = finalRdd.splits.size
+ val finalStage = newStage(finalRdd, None)
val results = new Array[U](numOutputParts)
val finished = new Array[Boolean](numOutputParts)
var numFinished = 0
@@ -125,6 +122,8 @@ private abstract class DAGScheduler extends Scheduler with Logging {
val running = new HashSet[Stage]
val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
+ updateCacheLocs()
+
logInfo("Final stage: " + finalStage)
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
@@ -145,12 +144,13 @@ private abstract class DAGScheduler extends Scheduler with Logging {
}
def submitMissingTasks(stage: Stage) {
+ // Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
var tasks = ArrayBuffer[Task[_]]()
if (stage == finalStage) {
for (p <- 0 until numOutputParts if (!finished(p))) {
- val locs = getPreferredLocs(rdd, p)
- tasks += new ResultTask(finalStage.id, rdd, func, p, locs)
+ val locs = getPreferredLocs(finalRdd, p)
+ tasks += new ResultTask(finalStage.id, finalRdd, func, p, locs)
}
} else {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
@@ -186,6 +186,7 @@ private abstract class DAGScheduler extends Scheduler with Logging {
if (pending.isEmpty) {
logInfo(stage + " finished; looking for newly runnable stages")
running -= stage
+ updateCacheLocs()
val newlyRunnable = new ArrayBuffer[Stage]
for (stage <- waiting if getMissingParentStages(stage) == Nil) {
newlyRunnable += stage
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index ac62c6e411..a253176169 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -11,7 +11,7 @@ class MapOutputTracker extends DaemonActor with Logging {
val port = System.getProperty("spark.master.port", "50501").toInt
RemoteActor.alive(port)
RemoteActor.register('MapOutputTracker, self)
- logInfo("Started on port " + port)
+ logInfo("Registered actor on port " + port)
}
}
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala
index 2f2ec9d237..aae2d74900 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/RDDCache.scala
@@ -3,31 +3,57 @@ package spark
import scala.actors._
import scala.actors.Actor._
import scala.actors.remote._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
sealed trait CacheMessage
-case class CacheEntryAdded(rddId: Int, partition: Int, host: String)
-case class CacheEntryRemoved(rddId: Int, partition: Int, host: String)
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class MemoryCacheLost(host: String) extends CacheMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
+case object GetCacheLocations extends CacheMessage
class RDDCacheTracker extends DaemonActor with Logging {
+ val locs = new HashMap[Int, Array[List[String]]]
+ // TODO: Should probably store (String, CacheType) tuples
+
def act() {
val port = System.getProperty("spark.master.port", "50501").toInt
RemoteActor.alive(port)
RemoteActor.register('RDDCacheTracker, self)
- logInfo("Started on port " + port)
+ logInfo("Registered actor on port " + port)
loop {
react {
- case CacheEntryAdded(rddId, partition, host) =>
- logInfo("Cache entry added: %s, %s, %s".format(rddId, partition, host))
+ case RegisterRDD(rddId: Int, numPartitions: Int) =>
+ logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+ locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+ reply("")
+
+ case AddedToCache(rddId, partition, host) =>
+ logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+ locs(rddId)(partition) = host :: locs(rddId)(partition)
- case CacheEntryRemoved(rddId, partition, host) =>
- logInfo("Cache entry removed: %s, %s, %s".format(rddId, partition, host))
+ case DroppedFromCache(rddId, partition, host) =>
+ logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+ locs(rddId)(partition) -= host
+
+ case MemoryCacheLost(host) =>
+ logInfo("Memory cache lost on " + host)
+ // TODO: Drop host from the memory locations list of all RDDs
+
+ case GetCacheLocations =>
+ logInfo("Asked for current cache locations")
+ val locsCopy = new HashMap[Int, Array[List[String]]]
+ for ((rddId, array) <- locs) {
+ locsCopy(rddId) = array.clone()
+ }
+ reply(locsCopy)
}
}
}
}
-import scala.collection.mutable.HashSet
private object RDDCache extends Logging {
// Stores map results for various splits locally
val cache = Cache.newKeySpace()
@@ -38,6 +64,8 @@ private object RDDCache extends Logging {
// Tracker actor on the master, or remote reference to it on workers
var trackerActor: AbstractActor = null
+ val registeredRddIds = new HashSet[Int]
+
def initialize(isMaster: Boolean) {
if (isMaster) {
val tracker = new RDDCacheTracker
@@ -50,16 +78,34 @@ private object RDDCache extends Logging {
}
}
+ // Registers an RDD (on master only)
+ def registerRDD(rddId: Int, numPartitions: Int) {
+ registeredRddIds.synchronized {
+ if (!registeredRddIds.contains(rddId)) {
+ registeredRddIds += rddId
+ trackerActor !? RegisterRDD(rddId, numPartitions)
+ }
+ }
+ }
+
+ // Get a snapshot of the currently known locations
+ def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+ (trackerActor !? GetCacheLocations) match {
+ case h: HashMap[Int, Array[List[String]]] => h
+ case _ => throw new SparkException(
+ "Internal error: RDDCache did not reply with a HashMap")
+ }
+ }
+
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
: Iterator[T] = {
val key = (rdd.id, split.index)
- logInfo("CachedRDD split key is " + key)
- val cache = RDDCache.cache
- val loading = RDDCache.loading
+ logInfo("CachedRDD partition key is " + key)
val cachedVal = cache.get(key)
if (cachedVal != null) {
// Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
} else {
// Mark the split as loading (unless someone else marks it first)
@@ -73,13 +119,13 @@ private object RDDCache extends Logging {
loading.add(key)
}
}
- val host = System.getProperty("spark.hostname", Utils.localHostName)
- trackerActor ! CacheEntryAdded(rdd.id, split.index, host)
// If we got here, we have to load the split
+ // Tell the master that we're doing so
+ val host = System.getProperty("spark.hostname", Utils.localHostName)
+ trackerActor ! AddedToCache(rdd.id, split.index, host)
// TODO: fetch any remote copy of the split that may be available
- // TODO: also notify the master that we're loading it
// TODO: also register a listener for when it unloads
- logInfo("Computing and caching " + split)
+ logInfo("Computing partition " + split)
val array = rdd.compute(split).toArray(m)
cache.put(key, array)
loading.synchronized {
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index fda2ee3be7..5cce873c72 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -175,6 +175,7 @@ extends Logging {
private var nextRddId = new AtomicInteger(0)
+ // Register a new RDD, returning its RDD ID
private[spark] def newRddId(): Int = {
nextRddId.getAndIncrement()
}