path: root/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
package spark.scheduler.local

import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet

import spark._
import spark.TaskState.TaskState
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
import spark.scheduler.cluster._
import akka.actor._

/**
 * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool.
 * Optionally, the scheduler allows each task to fail up to maxFailures times, which is
 * useful for testing fault recovery.
 */
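//
// A minimal usage sketch (an illustration, assuming the SparkContext API of this version of
// Spark, where a "local[N]" master URL makes SparkContext construct a LocalScheduler with N
// worker threads; the app name is arbitrary):
//
//   val sc = new SparkContext("local[4]", "LocalSchedulerExample")
//   val sum = sc.parallelize(1 to 100).map(_ * 2).reduce(_ + _)  // tasks run on the thread pool
//   sc.stop()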

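// Messages used to drive the LocalActor's event loop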
private[spark] case object LocalReviveOffers
private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)

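/**
 * An actor that tracks the number of free cores, accepts revive-offer and status-update
 * messages, and hands launched tasks to the scheduler's thread pool.
 */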
private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
  def receive = {
    case LocalReviveOffers =>
      launchTask(localScheduler.resourceOffer(freeCores))
    case LocalStatusUpdate(taskId, state, serializedData) =>
      freeCores += 1
      localScheduler.statusUpdate(taskId, state, serializedData)
      launchTask(localScheduler.resourceOffer(freeCores))
  }

  def launchTask(tasks: Seq[TaskDescription]) {
    for (task <- tasks) {
      freeCores -= 1
      localScheduler.threadPool.submit(new Runnable {
        def run() {
          localScheduler.runTask(task.taskId, task.serializedTask)
        }
      })
    }
  }
}

private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
  extends TaskScheduler
  with Logging {

  val attemptId = new AtomicInteger(0)
  val threadPool = Utils.newDaemonFixedThreadPool(threads)
  val env = SparkEnv.get
  var listener: TaskSchedulerListener = null

  // Application dependencies (added through SparkContext) that we've fetched so far on this node.
  // Each map holds the master's timestamp for the version of that file or JAR we got.
  val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
  val currentJars: HashMap[String, Long] = new HashMap[String, Long]()

  val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)

  var schedulableBuilder: SchedulableBuilder = null
  var rootPool: Pool = null
  val activeTaskSets = new HashMap[String, TaskSetManager]
  val taskIdToTaskSetId = new HashMap[Long, String]
  val taskSetTaskIds = new HashMap[String, HashSet[Long]]

  var localActor: ActorRef = null

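  /** Build the FIFO or Fair pool hierarchy and start the actor that dispatches tasks. */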
  override def start() {
    // The default scheduling mode is FIFO
    val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
    // Temporarily set the rootPool name to empty
    rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
    schedulableBuilder = {
      schedulingMode match {
        case "FIFO" =>
          new FIFOSchedulableBuilder(rootPool)
        case "FAIR" =>
          new FairSchedulableBuilder(rootPool)
        case unknown =>
          // Fail fast with a clear message instead of an opaque MatchError
          throw new IllegalArgumentException("Unsupported spark.cluster.schedulingmode: " + unknown)
      }
    }
    schedulableBuilder.buildPools()

    localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "LocalActor")
  }

  override def setListener(listener: TaskSchedulerListener) {
    this.listener = listener
  }

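  /** Register a new TaskSet with the pool hierarchy and ask the actor to offer it resources. */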
  override def submitTasks(taskSet: TaskSet) {
    synchronized {
      val manager = new LocalTaskSetManager(this, taskSet)
      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
      activeTaskSets(taskSet.id) = manager
      taskSetTaskIds(taskSet.id) = new HashSet[Long]()
      localActor ! LocalReviveOffers
    }
  }

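  /**
   * Offer up to `freeCores` CPU cores to the task set managers, in the order chosen by the
   * root pool's scheduling algorithm, and return the tasks they decide to launch.
   */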
  def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
    synchronized {
      var freeCpuCores = freeCores
      val tasks = new ArrayBuffer[TaskDescription](freeCores)
      val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
      for (manager <- sortedTaskSetQueue) {
        logDebug("parentName: %s, name: %s, runningTasks: %s".format(manager.parent.name, manager.name, manager.runningTasks))
      }

      var launchTask = false
      for (manager <- sortedTaskSetQueue) {
        do {
          launchTask = false
          manager.slaveOffer(null, null, freeCpuCores) match {
            case Some(task) =>
              tasks += task
              taskIdToTaskSetId(task.taskId) = manager.taskSet.id
              taskSetTaskIds(manager.taskSet.id) += task.taskId
              freeCpuCores -= 1
              launchTask = true
            case None =>
          }
        } while (launchTask)
      }
      tasks
    }
  }

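  /** Remove a finished TaskSet from the pool hierarchy and drop its task-id bookkeeping. */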
  def taskSetFinished(manager: TaskSetManager) {
    synchronized {
      activeTaskSets -= manager.taskSet.id
      manager.parent.removeSchedulable(manager)
      logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
      taskSetTaskIds -= manager.taskSet.id
    }
  }

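  /**
   * Deserialize and run a single task on the calling thread-pool thread, mirroring the
   * serialization round trips the cluster executor performs, then report the result (or
   * failure) back through the local actor.
   */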
  def runTask(taskId: Long, bytes: ByteBuffer) {
    logInfo("Running " + taskId)
    val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
    // Set the Spark execution environment for the worker thread
    SparkEnv.set(env)
    val ser = SparkEnv.get.closureSerializer.newInstance()
    var attemptedTask: Option[Task[_]] = None
    val start = System.currentTimeMillis()
    var taskStart: Long = 0
    try {
      Accumulators.clear()
      Thread.currentThread().setContextClassLoader(classLoader)

      // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
      // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
      updateDependencies(taskFiles, taskJars)   // Download any files added with addFile
      val deserializedTask = ser.deserialize[Task[_]](
        taskBytes, Thread.currentThread.getContextClassLoader)
      attemptedTask = Some(deserializedTask)
      val deserTime = System.currentTimeMillis() - start
      taskStart = System.currentTimeMillis()

      // Run it
      val result: Any = deserializedTask.run(taskId)

      // Serialize and deserialize the result to emulate what the Mesos
      // executor does. This is useful to catch serialization errors early
      // on in development (so when users move their local Spark programs
      // to the cluster, they don't get surprised by serialization errors).
      val serResult = ser.serialize(result)
      deserializedTask.metrics.get.resultSize = serResult.limit()
      val resultToReturn = ser.deserialize[Any](serResult)
      val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
        ser.serialize(Accumulators.values))
      val serviceTime = System.currentTimeMillis() - taskStart
      logInfo("Finished " + taskId)
      deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
      deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
      val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
      val serializedResult = ser.serialize(taskResult)
      localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
    } catch {
      case t: Throwable => {
        // Guard against taskStart still being 0 if the failure happened during deserialization
        val serviceTime = if (taskStart > 0) System.currentTimeMillis() - taskStart else 0
        val metrics = attemptedTask.flatMap(t => t.metrics)
        metrics.foreach { m => m.executorRunTime = serviceTime.toInt }
        val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
        localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
      }
    }
  }

  /**
   * Download any missing dependencies if we receive a new set of files and JARs from the
   * SparkContext. Also adds any new JARs we fetched to the class loader.
   */
  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
    synchronized {
      // Fetch missing dependencies
      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
        currentFiles(name) = timestamp
      }

      for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
        currentJars(name) = timestamp
        // Add it to our class loader
        val localName = name.split("/").last
        val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
        if (!classLoader.getURLs.contains(url)) {
          logInfo("Adding " + url + " to class loader")
          classLoader.addURL(url)
        }
      }
    }
  }

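  /** Forward a task's terminal state to its TaskSetManager and clear its task-id mapping. */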
  def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
    synchronized {
      val taskSetId = taskIdToTaskSetId(taskId)
      val taskSetManager = activeTaskSets(taskSetId)
      taskSetTaskIds(taskSetId) -= taskId
      taskSetManager.statusUpdate(taskId, state, serializedData)
    }
  }

  override def stop() {
    threadPool.shutdownNow()
  }

  override def defaultParallelism() = threads
}