path: root/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
package spark.deploy.yarn

import java.net.Socket
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import scala.collection.JavaConversions._
import spark.{SparkContext, Logging, Utils}
import org.apache.hadoop.security.UserGroupInformation
import java.security.PrivilegedExceptionAction

class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {

  def this(args: ApplicationMasterArguments) = this(args, new Configuration())
  
  private var rpc: YarnRPC = YarnRPC.create(conf)
  private var resourceManager: AMRMProtocol = null
  private var appAttemptId: ApplicationAttemptId = null
  private var userThread: Thread = null
  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)

  private var yarnAllocator: YarnAllocationHandler = null
  private var isFinished: Boolean = false

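  // Main entry point of the ApplicationMaster: register with the ResourceManager, start the user
  // class in a separate thread, wait for the Spark driver to come up, allocate worker containers,
  // and finally wait for the user thread to finish before exiting.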
  def run() {
    
    appAttemptId = getApplicationAttemptId()
    resourceManager = registerWithResourceManager()
    val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()

    // Compute the number of threads for Akka.
    val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()

    if (minimumMemory > 0) {
      val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
      val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)

      if (numCore > 0) {
        // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
        // TODO: Uncomment when hadoop is on a version which has this fixed.
        // args.workerCores = numCore
      }
    }

    // Workaround until Hadoop moves to a version which has
    // https://issues.apache.org/jira/browse/HADOOP-8406 fixed (ignore the result).
    // This does not, unfortunately, always work reliably, but it alleviates the bug most of the
    // time - hence args.workerCores = numCore is disabled above. Any better option?
    // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
    
    ApplicationMaster.register(this)
    // Start the user's JAR
    userThread = startUserClass()
    
    // This is a bit hacky, but we need to wait until the spark.driver.port property has
    // been set by the thread executing the user class.
    waitForSparkMaster()
    
    // Allocate all containers
    allocateWorkers()
    
    // Wait for the user class to finish
    userThread.join()

    System.exit(0)
  }
  
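  // Derive the ApplicationAttemptId from the container id that YARN exposes via the environment.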
  private def getApplicationAttemptId(): ApplicationAttemptId = {
    val envs = System.getenv()
    val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
    val containerId = ConverterUtils.toContainerId(containerIdString)
    val appAttemptId = containerId.getApplicationAttemptId()
    logInfo("ApplicationAttemptId: " + appAttemptId)
    return appAttemptId
  }
  
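  // Build an RPC proxy to the ResourceManager's scheduler interface (AMRMProtocol).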
  private def registerWithResourceManager(): AMRMProtocol = {
    val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
      YarnConfiguration.RM_SCHEDULER_ADDRESS,
      YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
    logInfo("Connecting to ResourceManager at " + rmAddress)
    return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
  }
  
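  // Register this ApplicationMaster with the ResourceManager so it can start receiving container
  // allocations.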
  private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
    logInfo("Registering the ApplicationMaster")
    val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
      .asInstanceOf[RegisterApplicationMasterRequest]
    appMasterRequest.setApplicationAttemptId(appAttemptId)
    // Set this to the AM's host so that the ApplicationReport at the client has some sensible
    // info. Users can then monitor stderr/stdout on that node if required.
    appMasterRequest.setHost(Utils.localHostName())
    appMasterRequest.setRpcPort(0)
    // What should we provide here? It might make sense to expose something sensible later.
    appMasterRequest.setTrackingUrl("")
    return resourceManager.registerApplicationMaster(appMasterRequest)
  }
  
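  // Poll until the driver started by the user class is reachable, i.e. until spark.driver.host
  // and spark.driver.port have been set and a socket connection to them succeeds.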
  private def waitForSparkMaster() {
    logInfo("Waiting for spark driver to be reachable.")
    var driverUp = false
    while(!driverUp) {
      val driverHost = System.getProperty("spark.driver.host")
      val driverPort = System.getProperty("spark.driver.port")
      try {
        val socket = new Socket(driverHost, driverPort.toInt)
        socket.close()
        logInfo("Master now available: " + driverHost + ":" + driverPort)
        driverUp = true
      } catch {
        case e: Exception =>
          logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
          Thread.sleep(100)
      }
    }
  }

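  // Load the user's main class and invoke its main() method on a dedicated thread, reporting
  // SUCCEEDED or FAILED to the ResourceManager when that thread terminates.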
  private def startUserClass(): Thread = {
    logInfo("Starting the user JAR in a separate Thread")
    val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
      .getMethod("main", classOf[Array[String]])
    val t = new Thread {
      override def run() {
        var succeeded = false
        try {
          // Copy the user arguments before handing them to the user class's main method.
          val mainArgs: Array[String] = new Array[String](args.userArgs.size())
          args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size())
          mainMethod.invoke(null, mainArgs)
          // Some user programs end with System.exit(0) (for example SparkPi, SparkLR), in which
          // case this thread stops right here unless an uncaught exception is thrown; the
          // shutdown hook is then relied on to report SUCCEEDED.
          succeeded = true
        } finally {
          if (succeeded) {
            ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
          } else {
            ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED)
          }
        }
      }
    }
    t.start()
    return t
  }

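  // Wait for the SparkContext to be initialized, then request containers from YARN until the
  // desired number of workers is running (or the user thread exits).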
  private def allocateWorkers() {
    logInfo("Waiting for spark context initialization")

    try {
      var sparkContext: SparkContext = null
      ApplicationMaster.sparkContextRef.synchronized {
        var count = 0
        while (ApplicationMaster.sparkContextRef.get() == null) {
          logInfo("Waiting for spark context initialization ... " + count)
          count = count + 1
          ApplicationMaster.sparkContextRef.wait(10000L)
        }
        sparkContext = ApplicationMaster.sparkContextRef.get()
        assert(sparkContext != null)
        this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager,
          appAttemptId, args, sparkContext.preferredNodeLocationData)
      }

      logInfo("Allocating " + args.numWorkers + " workers.")
      // Wait until all requested containers are up and running.
      // TODO: This is a bit ugly. Can we make it nicer?
      // TODO: Handle container failure
      while (yarnAllocator.getNumWorkersRunning < args.numWorkers &&
        // If the user thread exits, quit the loop.
        userThread.isAlive) {

          this.yarnAllocator.allocateContainers(
            math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
          ApplicationMaster.incrementAllocatorLoop(1)
          Thread.sleep(100)
      }
    } finally {
      // In case of exceptions, etc., ensure that the count is at least ALLOCATOR_LOOP_WAIT_COUNT,
      // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks.
      ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
    }
    logInfo("All workers have launched.")

    // Launch a progress reporter thread, otherwise the app gets killed once the AM expiry
    // interval (default: 10 minutes) elapses without a heartbeat.
    if (userThread.isAlive) {
      // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
      val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
      // The interval must be <= timeoutInterval / 2. On the other hand, we should stay reasonably
      // responsive without flooding the RM with requests, so use at least 1 minute or
      // timeoutInterval / 10, whichever is higher.
      val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L))
      launchReporterThread(interval)
    }
  }

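  // Start a daemon thread that periodically heartbeats to the ResourceManager (re-requesting any
  // missing containers along the way) so that the application is not expired.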
  // TODO: We might want to extend this to allocate more containers in case they die.
  private def launchReporterThread(_sleepTime: Long): Thread = {
    val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime

    val t = new Thread {
      override def run() {
        while (userThread.isAlive) {
          val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
          if (missingWorkerCount > 0) {
            logInfo("Allocating " + missingWorkerCount + " containers to make up for potentially lost containers")
            yarnAllocator.allocateContainers(missingWorkerCount)
          } else {
            sendProgress()
          }
          Thread.sleep(sleepTime)
        }
      }
    }
    // Set to daemon status so the reporter thread does not keep the JVM alive on its own.
    t.setDaemon(true)
    t.start()
    logInfo("Started progress reporter thread - sleep time: " + sleepTime)
    return t
  }

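  // Heartbeat to the ResourceManager without requesting any additional containers.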
  private def sendProgress() {
    logDebug("Sending progress")
    // Simulated with an allocate request that asks for no new containers.
    yarnAllocator.allocateContainers(0)
  }

  /*
  def printContainers(containers: List[Container]) = {
    for (container <- containers) {
      logInfo("Launching shell command on a new container."
        + ", containerId=" + container.getId()
        + ", containerNode=" + container.getNodeId().getHost() 
        + ":" + container.getNodeId().getPort()
        + ", containerNodeURI=" + container.getNodeHttpAddress()
        + ", containerState" + container.getState()
        + ", containerResourceMemory"  
        + container.getResource().getMemory())
    }
  }
  */

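  // Report the final status of the application to the ResourceManager. Safe to call more than
  // once; only the first call has any effect.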
  def finishApplicationMaster(status: FinalApplicationStatus) {

    synchronized {
      if (isFinished) {
        return
      }
      isFinished = true
    }

    logInfo("finishApplicationMaster with " + status)
    val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
      .asInstanceOf[FinishApplicationMasterRequest]
    finishReq.setAppAttemptId(appAttemptId)
    finishReq.setFinishApplicationStatus(status)
    resourceManager.finishApplicationMaster(finishReq)

  }
 
}

object ApplicationMaster {
  // Number of times to wait for the allocator loop to complete.
  // Each loop iteration waits for 100 ms, so this is a maximum of about 3 seconds.
  // This is to ensure that we have a reasonable number of containers before we start.
  // TODO: Currently, the task-to-container mapping is computed once (in TaskSetManager), which
  // need not be optimal as more containers become available. Might need to handle this better.
  private val ALLOCATOR_LOOP_WAIT_COUNT = 30
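  // Bump the allocator-loop counter and wake up any thread blocked in sparkContextInitialized()
  // once enough iterations have completed.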
  def incrementAllocatorLoop(by: Int) {
    val count = yarnAllocatorLoop.getAndAdd(by)
    if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
      yarnAllocatorLoop.synchronized {
        // Wake up any threads waiting on yarnAllocatorLoop.
        yarnAllocatorLoop.notifyAll()
      }
    }
  }

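  // All ApplicationMaster instances registered in this JVM, so the shutdown hook can finish each
  // of them.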
  private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()

  def register(master: ApplicationMaster) {
    applicationMasters.add(master)
  }

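  // Shared state between the thread running the user's SparkContext and the AM's allocation loop.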
  val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
  val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)

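  // Publishes the SparkContext for the allocation loop, installs a shutdown hook, and blocks
  // until at least some containers have been allocated. A minimal sketch of the call site
  // (assuming the YARN cluster scheduler wires this up once the context is constructed):
  //   ApplicationMaster.sparkContextInitialized(sc)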
  def sparkContextInitialized(sc: SparkContext): Boolean = {
    var modified = false
    sparkContextRef.synchronized {
      modified = sparkContextRef.compareAndSet(null, sc)
      sparkContextRef.notifyAll()
    }

    // Add a shutdown hook as a best-effort measure in case users do not call sc.stop() or System.exit().
    // We should not really have to do this, but it helps YARN reclaim resources earlier - not to
    // mention, it prevents the Client from declaring failure even though we exited properly.
    if (modified) {
      Runtime.getRuntime().addShutdownHook(new Thread with Logging { 
        // This is not just for logging: it also ensures that the log system is initialized for
        // this instance before the hook is actually run.
        logInfo("Adding shutdown hook for context " + sc)
        override def run() { 
          logInfo("Invoking sc stop from shutdown hook") 
          sc.stop() 
          // Best effort: finish every registered ApplicationMaster.
          for (master <- applicationMasters) {
            master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
          }
        } 
      } )
    }

    // Wait for initialization to complete and for at least some containers to be allocated.
    yarnAllocatorLoop.synchronized {
      while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
        yarnAllocatorLoop.wait(1000L)
      }
    }
    modified
  }

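  // Entry point used when YARN launches the AM container. A hypothetical launch command (argument
  // names follow ApplicationMasterArguments and may differ) might look like:
  //   java spark.deploy.yarn.ApplicationMaster --jar app.jar --class org.example.MyApp \
  //     --num-workers 2 --worker-memory 1024 --worker-cores 1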
  def main(argStrings: Array[String]) {
    val args = new ApplicationMasterArguments(argStrings)
    new ApplicationMaster(args).run()
  }
}