aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
blob: 69822f568c78abf559ef2045ab37fd04026233dd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package spark.scheduler.cluster

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}

import akka.actor._
import akka.util.duration._
import akka.pattern.ask

import spark.{SparkException, Logging, TaskState}
import akka.dispatch.Await
import java.util.concurrent.atomic.AtomicInteger
import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}

/**
 * A standalone scheduler backend, which waits for standalone executors to connect to it through
 * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained
 * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*).
 */
private[spark]
class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
  extends SchedulerBackend with Logging {

  // Use an atomic variable to track total number of cores in the cluster for simplicity and speed.
  // It is updated by the master actor and read by defaultParallelism() from the caller's thread.
  var totalCoreCount = new AtomicInteger(0)

  /**
   * Actor that tracks the connected executors and their free cores, and turns the scheduler's
   * answers to resource offers into LaunchTask messages.
   *
   * @param sparkProperties the `spark.*` system properties sent to each executor when it registers
   */
  class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor {
    val executorActor = new HashMap[String, ActorRef]
    val executorAddress = new HashMap[String, Address]
    val executorHost = new HashMap[String, String]
    val freeCores = new HashMap[String, Int]
    // Number of cores each executor registered with. Removing an executor must subtract its full
    // contribution from totalCoreCount, not just the cores that happened to be idle at that moment.
    val totalCores = new HashMap[String, Int]
    val actorToExecutorId = new HashMap[ActorRef, String]
    val addressToExecutorId = new HashMap[Address, String]

    override def preStart() {
      // Listen for remote client disconnection events, since they don't go through Akka's watch()
      context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
    }

    def receive = {
      case RegisterExecutor(executorId, host, cores) =>
        if (executorActor.contains(executorId)) {
          sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
        } else {
          logInfo("Registered executor: " + sender + " with ID " + executorId)
          sender ! RegisteredExecutor(sparkProperties)
          context.watch(sender)
          executorActor(executorId) = sender
          executorHost(executorId) = host
          freeCores(executorId) = cores
          totalCores(executorId) = cores
          executorAddress(executorId) = sender.path.address
          actorToExecutorId(sender) = executorId
          addressToExecutorId(sender.path.address) = executorId
          totalCoreCount.addAndGet(cores)
          makeOffers()
        }

      case StatusUpdate(executorId, taskId, state, data) =>
        scheduler.statusUpdate(taskId, state, data.value)
        if (TaskState.isFinished(state)) {
          if (executorActor.contains(executorId)) {
            freeCores(executorId) += 1
            makeOffers(executorId)
          } else {
            // The executor has already been removed (e.g. a late update racing a disconnect).
            // Indexing freeCores/executorHost would throw inside this actor, so drop the update.
            logWarning("Ignored task status update (" + taskId + " state " + state +
              ") from unknown executor " + executorId)
          }
        }

      case ReviveOffers =>
        makeOffers()

      case StopMaster =>
        sender ! true
        context.stop(self)

      case Terminated(actor) =>
        actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))

      case RemoteClientDisconnected(transport, address) =>
        addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))

      case RemoteClientShutdown(transport, address) =>
        addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
    }

    // Make fake resource offers on all executors
    def makeOffers() {
      launchTasks(scheduler.resourceOffers(
        executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
    }

    // Make fake resource offers on just one executor
    def makeOffers(executorId: String) {
      launchTasks(scheduler.resourceOffers(
        Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
    }

    // Launch tasks returned by a set of resource offers, reserving one core per launched task
    def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
      for (task <- tasks.flatten) {
        freeCores(task.executorId) -= 1
        executorActor(task.executorId) ! LaunchTask(task)
      }
    }

    // Remove a disconnected executor from the cluster and tell the scheduler it was lost.
    // Calling this for an executor that is not (or no longer) registered is a no-op.
    def removeExecutor(executorId: String, reason: String) {
      if (executorActor.contains(executorId)) {
        logInfo("Slave " + executorId + " disconnected, so removing it")
        val numCores = totalCores(executorId)
        actorToExecutorId -= executorActor(executorId)
        addressToExecutorId -= executorAddress(executorId)
        executorActor -= executorId
        executorAddress -= executorId
        executorHost -= executorId
        freeCores -= executorId
        totalCores -= executorId
        totalCoreCount.addAndGet(-numCores)
        scheduler.executorLost(executorId, SlaveLost(reason))
      }
    }
  }

  var masterActor: ActorRef = null
  val taskIdsOnSlave = new HashMap[String, HashSet[String]]

  /** Creates the master actor, handing it every `spark.*` system property visible right now. */
  def start() {
    val properties = new ArrayBuffer[(String, String)]
    val iterator = System.getProperties.entrySet.iterator
    while (iterator.hasNext) {
      val entry = iterator.next
      val (key, value) = (entry.getKey.toString, entry.getValue.toString)
      if (key.startsWith("spark.")) {
        properties += ((key, value))
      }
    }
    masterActor = actorSystem.actorOf(
      Props(new MasterActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
  }

  /**
   * Asks the master actor to stop and waits up to 5 seconds for its acknowledgement.
   *
   * @throws SparkException if the ask times out or otherwise fails
   */
  def stop() {
    try {
      if (masterActor != null) {
        val timeout = 5.seconds
        val future = masterActor.ask(StopMaster)(timeout)
        Await.result(future, timeout)
      }
    } catch {
      case e: Exception =>
        throw new SparkException("Error stopping standalone scheduler's master actor", e)
    }
  }

  /** Asynchronously asks the master actor to re-offer all executors' free cores. */
  def reviveOffers() {
    masterActor ! ReviveOffers
  }

  /** Total registered cores across executors, but never less than 2. */
  def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
}

/** Companion holding constants for [[StandaloneSchedulerBackend]]. */
private[spark] object StandaloneSchedulerBackend {
  /**
   * Name given to the master actor when `start()` creates it via `actorOf`.
   * NOTE(review): presumably executors use this name to look the actor up — confirm.
   */
  val ACTOR_NAME: String = "StandaloneScheduler"
}