diff options
Diffstat (limited to 'core')
4 files changed, 43 insertions, 68 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d884095671..a9600336f0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -246,7 +246,6 @@ class SparkContext( taskScheduler.start() @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) - dagScheduler.start() ui.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d0b21e896e..42bb3884c8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,9 +19,10 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import java.util.concurrent.atomic.AtomicInteger +import akka.actor._ +import akka.util.duration._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import org.apache.spark._ @@ -65,12 +66,12 @@ class DAGScheduler( // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { - eventQueue.put(BeginEvent(task, taskInfo)) + eventProcessActor ! BeginEvent(task, taskInfo) } // Called to report that a task has completed and results are being fetched remotely. def taskGettingResult(task: Task[_], taskInfo: TaskInfo) { - eventQueue.put(GettingResultEvent(task, taskInfo)) + eventProcessActor ! GettingResultEvent(task, taskInfo) } // Called by TaskScheduler to report task completions or failures. @@ -81,23 +82,23 @@ class DAGScheduler( accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) + eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) } // Called by TaskScheduler when an executor fails. def executorLost(execId: String) { - eventQueue.put(ExecutorLost(execId)) + eventProcessActor ! ExecutorLost(execId) } // Called by TaskScheduler when a host is added def executorGained(execId: String, host: String) { - eventQueue.put(ExecutorGained(execId, host)) + eventProcessActor ! ExecutorGained(execId, host) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. def taskSetFailed(taskSet: TaskSet, reason: String) { - eventQueue.put(TaskSetFailed(taskSet, reason)) + eventProcessActor ! TaskSetFailed(taskSet, reason) } // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; @@ -109,7 +110,30 @@ class DAGScheduler( // resubmit failed stages val POLL_TIMEOUT = 10L - private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] + private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { + override def preStart() { + context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) { + if (failed.size > 0) { + resubmitFailedStages() + } + } + } + + /** + * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure + * events and responds by launching tasks. This runs in a dedicated thread and receives events + * via the eventQueue. + */ + def receive = { + case event: DAGSchedulerEvent => + logDebug("Got event of type " + event.getClass.getName) + + if (!processEvent(event)) + submitWaitingStages() + else + context.stop(self) + } + })) private[scheduler] val nextJobId = new AtomicInteger(0) @@ -150,16 +174,6 @@ class DAGScheduler( val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) - // Start a thread to run the DAGScheduler event loop - def start() { - new Thread("DAGScheduler") { - setDaemon(true) - override def run() { - DAGScheduler.this.run() - } - }.start() - } - def addSparkListener(listener: SparkListener) { listenerBus.addListener(listener) } @@ -301,8 +315,7 @@ class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) - eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, - waiter, properties)) + eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) waiter } @@ -337,8 +350,7 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() - eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, - listener, properties)) + eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) listener.awaitResult() // Will throw an exception if the job fails } @@ -347,19 +359,19 @@ class DAGScheduler( */ def cancelJob(jobId: Int) { logInfo("Asked to cancel job " + jobId) - eventQueue.put(JobCancelled(jobId)) + eventProcessActor ! JobCancelled(jobId) } def cancelJobGroup(groupId: String) { logInfo("Asked to cancel job group " + groupId) - eventQueue.put(JobGroupCancelled(groupId)) + eventProcessActor ! JobGroupCancelled(groupId) } /** * Cancel all jobs that are running or waiting in the queue. */ def cancelAllJobs() { - eventQueue.put(AllJobsCancelled) + eventProcessActor ! AllJobsCancelled } /** @@ -474,42 +486,6 @@ class DAGScheduler( } } - - /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. - */ - private def run() { - SparkEnv.set(env) - - while (true) { - val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) - if (event != null) { - logDebug("Got event of type " + event.getClass.getName) - } - this.synchronized { // needed in case other threads makes calls into methods of this class - if (event != null) { - if (processEvent(event)) { - return - } - } - - val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability - // Periodically resubmit failed stages if some map output fetches have failed and we have - // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, - // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at - // the same time, so we want to make sure we've identified all the reduce tasks that depend - // on the failed node. - if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - resubmitFailedStages() - } else { - submitWaitingStages() - } - } - } - } - /** * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. * We run the operation in a separate thread just in case it takes a bunch of time, so that we @@ -878,7 +854,7 @@ class DAGScheduler( // If the RDD has narrow dependencies, pick the first partition of the first narrow dep // that has any placement preferences. Ideally we would choose based on transfer sizes, // but this will do for now. - rdd.dependencies.foreach(_ match { + rdd.dependencies.foreach { case n: NarrowDependency[_] => for (inPart <- n.getParents(partition)) { val locs = getPreferredLocs(n.rdd, inPart) @@ -886,7 +862,7 @@ class DAGScheduler( return locs } case _ => - }) + } Nil } @@ -909,7 +885,7 @@ class DAGScheduler( } def stop() { - eventQueue.put(StopDAGScheduler) + eventProcessActor ! StopDAGScheduler metadataCleaner.cancel() taskSched.stop() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a34c95b6f0..702aca8323 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -891,9 +891,9 @@ private[spark] object BlockManager extends Logging { blockManagerMaster: BlockManagerMaster = null) : Map[BlockId, Seq[BlockManagerId]] = { - // env == null and blockManagerMaster != null is used in tests + // blockManagerMaster != null is used in tests assert (env != null || blockManagerMaster != null) - val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) { + val blockLocations: Seq[Seq[BlockManagerId]] = if (blockManagerMaster == null) { env.blockManager.getLocationBlockIds(blockIds) } else { blockManagerMaster.getLocations(blockIds) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 00f2fdd657..a4d41ebbff 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -100,7 +100,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont cacheLocations.clear() results.clear() mapOutputTracker = new MapOutputTrackerMaster() - scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) { + scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, sc.env) { override def runLocally(job: ActiveJob) { // don't bother with the thread while unit testing runLocallyWithinThread(job) |