1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
|
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.scheduler.local
import java.io.File
import java.lang.management.ManagementFactory
import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark._
import spark.TaskState.TaskState
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
import spark.scheduler.cluster._
import spark.scheduler.cluster.SchedulingMode.SchedulingMode
import akka.actor._
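/*
* This file provides LocalScheduler, a FIFO or Fair TaskScheduler implementation that runs tasks
* locally in a thread pool. Optionally the scheduler also allows each task to fail up to
* maxFailures times, which is useful for testing fault recovery. The message classes and the
* LocalActor defined first are the scheduler's supporting pieces.
*/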
/**
 * Message asking the LocalActor to request task offers from the LocalScheduler for its currently
 * free cores and to launch whatever tasks come back.
 *
 * Note: LocalScheduler.submitTasks posts the companion object (`LocalReviveOffers`, without
 * parentheses) rather than an instance of this class.
 */
private[spark]
case class LocalReviveOffers()
/**
 * Message reporting the outcome of a task to the LocalActor.
 *
 * @param taskId id of the task the update refers to
 * @param state state being reported; LocalScheduler.runTask sends FINISHED or FAILED
 * @param serializedData serialized payload for the update; runTask sends a serialized TaskResult
 *                       on success and a serialized ExceptionFailure on failure
 */
private[spark]
case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
/**
 * Actor through which all task launches of a LocalScheduler are funnelled. It tracks how many
 * cores are currently free and, whenever it is asked to revive offers or a task reports a status
 * update, asks the scheduler for new tasks and submits them to the scheduler's thread pool.
 *
 * @param localScheduler scheduler that supplies task offers, runs tasks and handles status updates
 * @param freeCores number of cores not currently occupied by a launched task; decremented for
 *                  every task launched and incremented for every status update received
 */
private[spark]
class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {

  def receive = {
    // `LocalReviveOffers` on its own denotes the case class's companion object, which is what
    // LocalScheduler.submitTasks sends. The second alternative additionally accepts an actual
    // `LocalReviveOffers()` instance, which would otherwise go unhandled.
    case LocalReviveOffers | LocalReviveOffers() =>
      launchTask(localScheduler.resourceOffer(freeCores))

    case LocalStatusUpdate(taskId, state, serializedData) =>
      // The reporting task no longer occupies a core; release it before asking for more work.
      freeCores += 1
      localScheduler.statusUpdate(taskId, state, serializedData)
      launchTask(localScheduler.resourceOffer(freeCores))
  }

  /**
   * Submits each of the given tasks to the scheduler's thread pool, consuming one free core
   * per task.
   *
   * @param tasks task descriptions returned by LocalScheduler.resourceOffer
   */
  def launchTask(tasks : Seq[TaskDescription]) {
    for (task <- tasks) {
      freeCores -= 1
      localScheduler.threadPool.submit(new Runnable {
        def run() {
          localScheduler.runTask(task.taskId, task.serializedTask)
        }
      })
    }
  }
}
/**
 * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
 * the scheduler also allows each task to fail up to maxFailures times, which is useful for
 * testing fault recovery.
 *
 * Requests to launch tasks and task status updates are posted as messages to `localActor`, which
 * is created in `start()`. The methods that read or modify the bookkeeping maps
 * (`activeTaskSets`, `taskIdToTaskSetId`, `taskSetTaskIds`, `currentFiles`, `currentJars`) do so
 * while holding this object's monitor.
 *
 * @param threads size of the thread pool that runs tasks; also the initial number of free cores
 *                handed to the LocalActor and the value returned by `defaultParallelism()`
 * @param maxFailures exposed as a `val`; not read inside this class
 * @param sc exposed as a `val`; not read inside this class
 */
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
  extends TaskScheduler
  with Logging {

  var attemptId = new AtomicInteger(0)
  var threadPool = Utils.newDaemonFixedThreadPool(threads)
  val env = SparkEnv.get
  var listener: TaskSchedulerListener = null

  // Application dependencies (added through SparkContext) that we've fetched so far on this node.
  // Each map holds the master's timestamp for the version of that file or JAR we got.
  val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
  val currentJars: HashMap[String, Long] = new HashMap[String, Long]()

  // Class loader installed on task threads; fetched JARs are appended to it.
  val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)

  var schedulableBuilder: SchedulableBuilder = null
  var rootPool: Pool = null
  val schedulingMode: SchedulingMode = SchedulingMode.withName(
    System.getProperty("spark.cluster.schedulingmode", "FIFO"))

  // Task set id -> manager, for every task set submitted and not yet finished.
  val activeTaskSets = new HashMap[String, TaskSetManager]
  // Task id -> id of the task set it belongs to, for tasks that were offered and have not yet
  // reported a FINISHED or FAILED status.
  val taskIdToTaskSetId = new HashMap[Long, String]
  // Task set id -> ids of its tasks that were offered and have not yet reported a status.
  val taskSetTaskIds = new HashMap[String, HashSet[Long]]

  var localActor: ActorRef = null

  /** Builds the root scheduling pool for the configured mode and creates the LocalActor. */
  override def start() {
    // temporarily set rootPool name to empty
    rootPool = new Pool("", schedulingMode, 0, 0)
    schedulableBuilder = {
      schedulingMode match {
        case SchedulingMode.FIFO =>
          new FIFOSchedulableBuilder(rootPool)
        case SchedulingMode.FAIR =>
          new FairSchedulableBuilder(rootPool)
      }
    }
    schedulableBuilder.buildPools()

    localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
  }

  override def setListener(listener: TaskSchedulerListener) {
    this.listener = listener
  }

  /**
   * Registers a new task set with the scheduling pool and asks the LocalActor to revive offers
   * so that its tasks can be launched.
   *
   * @param taskSet the tasks to schedule
   */
  override def submitTasks(taskSet: TaskSet) {
    synchronized {
      val manager = new LocalTaskSetManager(this, taskSet)
      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
      activeTaskSets(taskSet.id) = manager
      taskSetTaskIds(taskSet.id) = new HashSet[Long]()
      localActor ! LocalReviveOffers
    }
  }

  /**
   * Collects tasks to launch by offering the free cores to every task set manager, in the order
   * given by the root pool's sorted queue. Each manager is asked repeatedly until it declines.
   *
   * @param freeCores number of cores currently available for new tasks
   * @return descriptions of the tasks to launch (possibly empty)
   */
  def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
    synchronized {
      var freeCpuCores = freeCores
      val tasks = new ArrayBuffer[TaskDescription](freeCores)
      val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
      for (manager <- sortedTaskSetQueue) {
        logDebug("parentName:%s,name:%s,runningTasks:%s".format(
          manager.parent.name, manager.name, manager.runningTasks))
      }

      var launchTask = false
      for (manager <- sortedTaskSetQueue) {
        do {
          launchTask = false
          manager.resourceOffer(null, null, freeCpuCores, null) match {
            case Some(task) =>
              tasks += task
              taskIdToTaskSetId(task.taskId) = manager.taskSet.id
              taskSetTaskIds(manager.taskSet.id) += task.taskId
              freeCpuCores -= 1
              launchTask = true
            case None => {}
          }
        } while(launchTask)
      }
      tasks
    }
  }

  /**
   * Removes a task set from the scheduler's bookkeeping and from its parent pool.
   *
   * @param manager manager of the task set that is done
   */
  def taskSetFinished(manager: TaskSetManager) {
    synchronized {
      val taskSetId = manager.taskSet.id
      activeTaskSets -= taskSetId
      manager.parent.removeSchedulable(manager)
      logInfo("Remove TaskSet %s from pool %s".format(taskSetId, manager.parent.name))
      // Drop the mappings of tasks that have not reported yet. `remove` is used instead of a
      // throwing lookup so that a task set whose entry is already gone is simply skipped.
      taskSetTaskIds.remove(taskSetId).foreach { pendingIds =>
        taskIdToTaskSetId --= pendingIds
      }
    }
  }

  /**
   * Runs one serialized task on the calling thread and reports the outcome to the LocalActor as
   * a LocalStatusUpdate: FINISHED with a serialized TaskResult, or FAILED with a serialized
   * ExceptionFailure if anything is thrown along the way.
   *
   * @param taskId id of the task being run
   * @param bytes the task together with its file and JAR dependencies, in the format understood
   *              by Task.deserializeWithDependencies
   */
  def runTask(taskId: Long, bytes: ByteBuffer) {
    logInfo("Running " + taskId)
    // Set the Spark execution environment for the worker thread
    SparkEnv.set(env)
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val objectSer = SparkEnv.get.serializer.newInstance()
    var attemptedTask: Option[Task[_]] = None
    val start = System.currentTimeMillis()
    var taskStart: Long = 0
    def getTotalGCTime =
      ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
    val startGCTime = getTotalGCTime

    try {
      Accumulators.clear()
      Thread.currentThread().setContextClassLoader(classLoader)

      // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
      // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
      updateDependencies(taskFiles, taskJars)   // Download any files added with addFile
      val deserializedTask = ser.deserialize[Task[_]](
        taskBytes, Thread.currentThread.getContextClassLoader)
      attemptedTask = Some(deserializedTask)
      val deserTime = System.currentTimeMillis() - start
      taskStart = System.currentTimeMillis()

      // Run it
      val result: Any = deserializedTask.run(taskId)

      // Serialize and deserialize the result to emulate what the Mesos
      // executor does. This is useful to catch serialization errors early
      // on in development (so when users move their local Spark programs
      // to the cluster, they don't get surprised by serialization errors).
      val serResult = objectSer.serialize(result)
      // Metrics are optional (see getOrElse(null) below and the failure path), so they are
      // only filled in when present instead of being force-unwrapped.
      for (m <- deserializedTask.metrics) {
        m.resultSize = serResult.limit()
      }
      // The deserialized value is not needed; the call only checks that the round trip works.
      objectSer.deserialize[Any](serResult)
      val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
        ser.serialize(Accumulators.values))
      val serviceTime = System.currentTimeMillis() - taskStart
      logInfo("Finished " + taskId)
      for (m <- deserializedTask.metrics) {
        m.executorRunTime = serviceTime.toInt
        m.jvmGCTime = getTotalGCTime - startGCTime
        m.executorDeserializeTime = deserTime.toInt
      }
      val taskResult =
        new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
      val serializedResult = ser.serialize(taskResult)
      localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
    } catch {
      // Every Throwable is caught so that any failure is reported back as a FAILED status.
      case t: Throwable => {
        val serviceTime = System.currentTimeMillis() - taskStart
        val metrics = attemptedTask.flatMap(_.metrics)
        for (m <- metrics) {
          m.executorRunTime = serviceTime.toInt
          m.jvmGCTime = getTotalGCTime - startGCTime
        }
        val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
        localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
      }
    }
  }

  /**
   * Download any missing dependencies if we receive a new set of files and JARs from the
   * SparkContext. Also adds any new JARs we fetched to the class loader.
   *
   * @param newFiles file name -> timestamp of the version the task expects
   * @param newJars JAR name -> timestamp of the version the task expects
   */
  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
    synchronized {
      // Fetch missing dependencies
      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
        currentFiles(name) = timestamp
      }

      for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
        currentJars(name) = timestamp
        // Add it to our class loader
        val localName = name.split("/").last
        val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
        if (!classLoader.getURLs.contains(url)) {
          logInfo("Adding " + url + " to class loader")
          classLoader.addURL(url)
        }
      }
    }
  }

  /**
   * Records a status update for a task and forwards it to the manager of its task set.
   *
   * An update for a task that is not (or no longer) tracked, for example because its task set
   * was already removed through taskSetFinished, is logged and ignored rather than raising
   * NoSuchElementException in the caller.
   *
   * @param taskId id of the task the update refers to
   * @param state the reported state
   * @param serializedData serialized payload that is passed through to the task set manager
   */
  def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
    synchronized {
      taskIdToTaskSetId.get(taskId) match {
        case Some(taskSetId) =>
          // FINISHED and FAILED are the states runTask reports, and both are final for a task
          // id. Forget the mapping here: taskSetFinished only clears ids that are still listed
          // in taskSetTaskIds, so otherwise this entry would never be removed.
          if (state == TaskState.FINISHED || state == TaskState.FAILED) {
            taskIdToTaskSetId -= taskId
          }
          taskSetTaskIds.get(taskSetId).foreach(_ -= taskId)
          activeTaskSets.get(taskSetId) match {
            case Some(taskSetManager) =>
              taskSetManager.statusUpdate(taskId, state, serializedData)
            case None =>
              logInfo("Ignoring update with state %s for task %d of inactive TaskSet %s".format(
                state, taskId, taskSetId))
          }
        case None =>
          logInfo("Ignoring update with state %s for unknown task %d".format(state, taskId))
      }
    }
  }

  /** Shuts down the task thread pool immediately. */
  override def stop() {
    threadPool.shutdownNow()
  }

  override def defaultParallelism() = threads
}
|