/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy.yarn

import java.io.IOException
import java.net.Socket
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

import scala.collection.JavaConversions._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}

import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.util.Utils

class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {

  /** Creates an ApplicationMaster backed by a fresh, default Hadoop [[Configuration]]. */
  def this(args: ApplicationMasterArguments) = this(args, new Configuration())

  private var rpc: YarnRPC = YarnRPC.create(conf)
  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
  private var appAttemptId: ApplicationAttemptId = _
  private var userThread: Thread = _
  // FileSystem that cleanupStagingDir() deletes the staging directory from.
  private val fs = FileSystem.get(yarnConf)

  private var yarnAllocator: YarnAllocationHandler = _
  // Checked and set under this object's lock by finishApplicationMaster().
  private var isFinished: Boolean = false
  // SparkContext UI address; passed as the tracking URL by registerApplicationMaster().
  private var uiAddress: String = _
  private val maxAppAttempts: Int = conf.getInt(
    YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
  // When true, AppMasterShutdownHook attempts to clean up the staging directory.
  private var isLastAMRetry: Boolean = true
  private var amClient: AMRMClient[ContainerRequest] = _

  // Worker-failure threshold checked by allocateWorkers() and the progress reporter thread.
  // Defaults to numWorkers * 2, with a minimum of 3. `conf` is a Hadoop Configuration, so it is
  // read with the typed getInt (as maxAppAttempts above is); Configuration has no getOrElse.
  private val maxNumWorkerFailures: Int = conf.getInt("spark.yarn.max.worker.failures",
    math.max(args.numWorkers * 2, 3))

  /**
   * Drives the ApplicationMaster. In the visible code it:
   *  - points `spark.local.dir` at the YARN-approved local dirs,
   *  - registers [[AppMasterShutdownHook]] with Hadoop's ShutdownHookManager (priority 30),
   *  - resolves the application attempt id and computes `isLastAMRetry`,
   *  - creates the AMRMClient, starts the user class via startUserClass(), and registers this
   *    ApplicationMaster through registerApplicationMaster().
   */
  def run() {
    // Setup the directories so things go to YARN approved directories rather
    // than user specified and /tmp.
    // NOTE(review): this writes into the Hadoop `conf`, not a Spark config or system property;
    // confirm the user's SparkContext actually picks `spark.local.dir` up from here.
    conf.set("spark.local.dir",  getLocalDirs())

    // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce uses.
    ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30)

    appAttemptId = getApplicationAttemptId()
    // isLastAMRetry gates staging-dir cleanup in AppMasterShutdownHook.
    isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts
    // NOTE(review): AMRMClient is a Hadoop service; no init(conf)/start() call is visible before
    // registerApplicationMaster() uses it below. Confirm it is started somewhere.
    amClient = AMRMClient.createAMRMClient()

    // Workaround until hadoop moves to something which has
    // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line)
    // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)


    // Start the user's JAR
    userThread = startUserClass()

    // This is a bit hacky, but we need to wait until the spark.driver.port property has
    // been set by the Thread executing the user class.
    // NOTE(review): no wait call (e.g. waitForSparkContextInitialized/waitForSparkMaster)
    // follows here in the visible source. Confirm the wait really happens before registering.


    // Do this after Spark master is up and SparkContext is created so that we can register UI Url.
    // The response is not used by any visible code.
    val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()

    // Allocate all containers
    // NOTE(review): no allocateWorkers() call is visible under this comment.

    // Wait for the user class to finish
    // NOTE(review): no userThread.join() (or equivalent) is visible under this comment.


  /**
   * Get the YARN-approved local directories for this container.
   *
   * @return the local directories as a single String
   * @throws Exception if no local directories are available
   */
  private def getLocalDirs(): String = {
    // Hadoop 0.23 and 2.x have different Environment variable names for the
    // local dirs, so let's check both. We assume one of the 2 is set.
    // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
    // NOTE(review): only YARN_LOCAL_DIRS is read below, so the LOCAL_DIRS (Hadoop 2.x) fallback
    // described above is missing. `localDirs` is also an Option[String] here, and
    // `Option.isEmpty` is parameterless, so `isEmpty()` likely does not compile. Verify.
    val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))

    if (localDirs.isEmpty()) {
      throw new Exception("Yarn Local dirs can't be empty")

  /**
   * Derives this AM's ApplicationAttemptId from the container id stored in the
   * CONTAINER_ID environment variable, and logs it.
   */
  private def getApplicationAttemptId(): ApplicationAttemptId = {
    val envs = System.getenv()
    // NOTE(review): this lookup yields null if CONTAINER_ID is unset; it is parsed below
    // without a check.
    val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
    val containerId = ConverterUtils.toContainerId(containerIdString)
    val appAttemptId = containerId.getApplicationAttemptId()
    logInfo("ApplicationAttemptId: " + appAttemptId)

  /**
   * Registers this ApplicationMaster through `amClient`, advertising the local host name,
   * port 0 and `uiAddress` as the tracking URL.
   *
   * `uiAddress` is assigned in waitForSparkContextInitialized(). If that has not run yet, the
   * tracking URL passed here is null.
   */
  private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
    logInfo("Registering the ApplicationMaster")
    amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)

  /**
   * Probes the driver's `spark.driver.host`/`spark.driver.port` (read from `conf`) by opening a
   * TCP socket. It tries up to `spark.yarn.applicationMaster.waitTries` times (default "10")
   * and stops at the first successful connection.
   */
  private def waitForSparkMaster() {
    logInfo("Waiting for Spark driver to be reachable.")
    var driverUp = false
    var tries = 0
    // NOTE(review): this key spells "applicationMaster", while waitForSparkContextInitialized()
    // reads "spark.yarn.ApplicationMaster.waitTries". Config keys are case-sensitive, so
    // confirm which one is intended.
    val numTries = conf.getOrElse("spark.yarn.applicationMaster.waitTries", "10").toInt
    while (!driverUp && tries < numTries) {
      val driverHost = conf.get("spark.driver.host")
      val driverPort = conf.get("spark.driver.port")
      try {
        // NOTE(review): the probe socket opened here is never closed in the visible code, so it
        // leaks a connection to the driver.
        val socket = new Socket(driverHost, driverPort.toInt)
        logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
        driverUp = true
      } catch {
        // Any Exception (e.g. connection refused or a non-numeric port) counts as one failed try.
        case e: Exception => {
          logWarning("Failed to connect to driver at %s:%s, retrying ...".
            format(driverHost, driverPort))
          tries = tries + 1

  /**
   * Loads the user's class, looks up its static `main(Array[String])`, and creates a Thread
   * whose run() invokes it with a copy of `args.userArgs`.
   *
   * @return the user-class Thread
   */
  private def startUserClass(): Thread  = {
    logInfo("Starting the user JAR in a separate Thread")
    // NOTE(review): the 3-argument Class.forName overload takes (name, initialize, loader), but
    // no class-name argument is visible here. Confirm the user class name is passed first.
    val mainMethod = Class.forName(
      false /* initialize */,
      Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
    val t = new Thread {
      override def run() {
        var successed = false
        try {
          // Copy the user arguments into a fresh array to pass to main().
          var mainArgs: Array[String] = new Array[String](args.userArgs.size)
          args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
          // main is static, hence the null receiver.
          mainMethod.invoke(null, mainArgs)
          // Some job scripts end with "System.exit(0)" (for example SparkPi, SparkLR). The JVM
          // then exits inside invoke() and the line below is never reached, so a shutdown hook
          // has to report SUCCEEDED.
          successed = true
        } finally {
          logDebug("finishing main")
          // Setting this makes AppMasterShutdownHook attempt staging-dir cleanup.
          isLastAMRetry = true
          if (successed) {
          } else {

  /**
   * Waits for the user's SparkContext to be published in ApplicationMaster.sparkContextRef.
   * While holding that reference's monitor, it checks up to
   * `spark.yarn.ApplicationMaster.waitTries` times (default "10"). If a context is found, its
   * UI address is stored in `uiAddress`. In both outcomes `yarnAllocator` is created.
   *
   * This needs to happen before allocateWorkers(), which reads `yarnAllocator`.
   */
  private def waitForSparkContextInitialized() {
    logInfo("Waiting for Spark context initialization")
    try {
      var sparkContext: SparkContext = null
      ApplicationMaster.sparkContextRef.synchronized {
        var numTries = 0
        // Wait per try (presumably ms); also used to report the total wait in the warning below.
        val waitTime = 10000L
        // NOTE(review): the key's case differs from waitForSparkMaster()'s
        // "spark.yarn.applicationMaster.waitTries". Confirm which one is intended.
        val maxNumTries = conf.getOrElse("spark.yarn.ApplicationMaster.waitTries", "10").toInt
        while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries) {
          logInfo("Waiting for Spark context initialization ... " + numTries)
          numTries = numTries + 1
          // NOTE(review): no wait(waitTime) call is visible in this loop body. As shown, it would
          // run through all tries without waiting.
        sparkContext = ApplicationMaster.sparkContextRef.get()
        assert(sparkContext != null || numTries >= maxNumTries)

        if (sparkContext != null) {
          uiAddress = sparkContext.ui.appUIAddress
          this.yarnAllocator = YarnAllocationHandler.newAllocator(
        } else {
          logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d".
            format(numTries * waitTime, maxNumTries))
          this.yarnAllocator = YarnAllocationHandler.newAllocator(
    } finally {
      // In case of exceptions, etc., ensure that the count is at least ALLOCATOR_LOOP_WAIT_COUNT
      // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks.
      // NOTE(review): no statement performing this is visible after the comment.

  /**
   * Loops while fewer than `args.numWorkers` workers are running and the user thread is alive.
   * On each pass it compares worker failures against `maxNumWorkerFailures`. After the loop,
   * if the user thread is still alive, it derives the progress-reporter heartbeat interval.
   */
  private def allocateWorkers() {
    try {
      logInfo("Allocating " + args.numWorkers + " workers.")
      // Loop until the requested number of workers are running.
      // TODO: This is a bit ugly. Can we make it nicer?
      // TODO: Handle container failure
      // Exits the loop if the user thread exits.
      while (yarnAllocator.getNumWorkersRunning < args.numWorkers && userThread.isAlive) {
        if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) {
          // NOTE(review): the call that receives the message below is not visible. It is
          // presumably finishApplicationMaster(FinalApplicationStatus.FAILED, ...); confirm.
            "max number of worker failures reached")
    } finally {
      // In case of exceptions, etc., ensure that the count is at least ALLOCATOR_LOOP_WAIT_COUNT
      // so that the loop in ApplicationMaster#sparkContextInitialized() breaks.
    logInfo("All workers have launched.")

    // Launch a progress reporter thread, else the app will get killed after expiration
    // (def: 10mins) timeout.
    // NOTE(review): the fallback passed to getInt below is 120000 ms (2 minutes), not 10 minutes.
    // It only applies when RM_AM_EXPIRY_INTERVAL_MS is unset.
    if (userThread.isAlive) {
      // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
      val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)

      // We want to be reasonably responsive without causing too many requests to RM.
      val schedulerInterval =
        conf.getOrElse("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong

      // Must be <= timeoutInterval / 2, so at least two heartbeats fit in one expiry window.
      val interval = math.min(timeoutInterval / 2, schedulerInterval)


  /**
   * Creates the progress-reporter thread. While the user thread is alive, it compares worker
   * failures against `maxNumWorkerFailures`. When `missingWorkerCount` is positive, it logs
   * that it is allocating replacement containers.
   *
   * @param _sleepTime requested sleep between iterations; values <= 0 are clamped to 0
   * @return the reporter thread
   */
  private def launchReporterThread(_sleepTime: Long): Thread = {
    val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime

    val t = new Thread {
      override def run() {
        while (userThread.isAlive) {
          if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) {
              "max number of worker failures reached")
          // NOTE(review): the trailing operand of this subtraction is not visible; the next
          // line shown is the `if` that consumes the result.
          val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning -
          if (missingWorkerCount > 0) {
            logInfo("Allocating %d containers to make up for (potentially) lost containers".
    // Setting to daemon status, though this is usually not a good idea.
    // NOTE(review): no setDaemon(true)/start() call is visible after this comment. Confirm.
    logInfo("Started progress reporter thread - sleep time : " + sleepTime)

  /**
   * Reports progress. Per the comment below, this is meant to be an allocate request that asks
   * for no containers.
   * NOTE(review): that request is not visible here. Confirm it is actually made.
   */
  private def sendProgress() {
    logDebug("Sending progress")
    // Simulated with an allocate request with no nodes requested.

  /**
   * Logs one line per container, describing its id, node, HTTP address, state and memory.
   *
   * @param containers containers to describe; an empty list logs nothing
   */
  def printContainers(containers: List[Container]): Unit = {
    for (container <- containers) {
      val nodeId = container.getNodeId()
      // Every field is rendered as "name=value"; previously the state and memory values were
      // glued directly onto their labels (e.g. "containerStateRUNNING").
      logInfo("Launching shell command on a new container."
        + ", containerId=" + container.getId()
        + ", containerNode=" + nodeId.getHost()
        + ":" + nodeId.getPort()
        + ", containerNodeURI=" + container.getNodeHttpAddress()
        + ", containerState=" + container.getState()
        + ", containerResourceMemory=" + container.getResource().getMemory())
    }
  }

  /**
   * Unregisters this ApplicationMaster through `amClient` with final status `status`.
   * `isFinished` is checked and set under this object's lock, so finishing is meant to happen
   * at most once.
   *
   * @param status final application status passed to unregisterApplicationMaster
   * @param diagnostics NOTE(review): not used in the visible code. The unregister call always
   *                    passes an empty appMessage; confirm this is intended.
   */
  def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") {
    synchronized {
      // NOTE(review): the body of this branch (presumably an early return) is not visible.
      if (isFinished) {
      isFinished = true

    logInfo("finishApplicationMaster with " + status)
    // Set tracking URL to empty since we don't have a history server.
    amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */)

  /**
   * Clean up the staging directory.
   */
  private def cleanupStagingDir() {
    // Recursively deletes the directory named by $SPARK_YARN_STAGING_DIR from `fs`, unless
    // spark.yarn.preserve.staging.files parses as true. IOExceptions are logged, not rethrown.
    var stagingDirPath: Path = null
    try {
      val preserveFiles = conf.getOrElse("spark.yarn.preserve.staging.files", "false").toBoolean
      if (!preserveFiles) {
        // NOTE(review): `new Path(...)` never yields null, so the null check below cannot fire.
        // A missing env var instead makes the Path constructor throw IllegalArgumentException,
        // which the IOException handler does not catch. Verify.
        stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR"))
        if (stagingDirPath == null) {
          logError("Staging directory is null")
        logInfo("Deleting staging directory " + stagingDirPath)
        // `true` requests a recursive delete.
        fs.delete(stagingDirPath, true)
    } catch {
      case ioe: IOException =>
        logError("Failed to cleanup staging dir " + stagingDirPath, ioe)

  // The shutdown hook that runs when a signal is received AND during normal close of the JVM.
  // run() registers it with Hadoop's ShutdownHookManager at priority 30. It only cleans up the
  // staging directory when `isLastAMRetry` is set.
  class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable {

    def run() {
      logInfo("AppMaster received a signal.")
      // We need to clean up the staging dir before HDFS is shut down.
      // Make sure we don't delete it until this is the last AM attempt.
      if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir()

object ApplicationMaster {
  /**
   * Number of allocator-loop iterations to wait for before sparkContextInitialized() returns.
   * At roughly 100ms per iteration this caps the wait at about 3 seconds. The intent is to have
   * a reasonable number of containers before work starts.
   *
   * TODO: task-to-container assignment is currently computed once (TaskSetManager), which need
   * not be optimal as more containers become available. Might need to handle this better.
   */
  private val ALLOCATOR_LOOP_WAIT_COUNT: Int = 30

  /** ApplicationMasters iterated by the shutdown hook that sparkContextInitialized() installs. */
  private val applicationMasters: CopyOnWriteArrayList[ApplicationMaster] =
    new CopyOnWriteArrayList[ApplicationMaster]

  /** The user's SparkContext. Starts out null and is set via compareAndSet(null, sc). */
  val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext]()

  /** Allocator-loop iteration count, compared against ALLOCATOR_LOOP_WAIT_COUNT; starts at 0. */
  val yarnAllocatorLoop: AtomicInteger = new AtomicInteger()

  /**
   * Adds `by` to yarnAllocatorLoop. If the previous count (getAndAdd returns the value before
   * the addition) had already reached ALLOCATOR_LOOP_WAIT_COUNT, it enters yarnAllocatorLoop's
   * monitor to wake threads waiting on it.
   */
  def incrementAllocatorLoop(by: Int) {
    val count = yarnAllocatorLoop.getAndAdd(by)
    if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
      yarnAllocatorLoop.synchronized {
        // to wake threads off wait ...
        // NOTE(review): the notifyAll() this comment refers to is not visible here.

  /**
   * NOTE(review): the body of this method is not visible. It presumably records `master` in
   * `applicationMasters`, which the shutdown hook in sparkContextInitialized iterates; confirm.
   */
  def register(master: ApplicationMaster) {

  // TODO(harvey): See whether this should be discarded - it isn't used anywhere atm...
  /**
   * Publishes `sc` into sparkContextRef if no context has been set yet. If this call is the
   * one that set it, a JVM shutdown hook is added for the context. The method then waits on
   * yarnAllocatorLoop's monitor while the allocator loop count is <= ALLOCATOR_LOOP_WAIT_COUNT.
   *
   * @param sc the user's SparkContext
   * @return presumably whether `sc` was installed (`modified`). The return expression is not
   *         visible; confirm.
   */
  def sparkContextInitialized(sc: SparkContext): Boolean = {
    var modified = false
    sparkContextRef.synchronized {
      modified = sparkContextRef.compareAndSet(null, sc)

    // Add a shutdown hook as a best-effort measure in case users do not call sc.stop or do
    // System.exit.
    // Should not really have to do this, but it helps YARN to evict resources earlier.
    // It also prevents the Client from declaring failure even though we exited properly.
    // Note that this will unfortunately not properly clean up the staging files because it gets
    // called too late, after the filesystem is already shut down.
    if (modified) {
      Runtime.getRuntime().addShutdownHook(new Thread with Logging {
        // This not only logs, but also ensures that the log system is initialized for this
        // instance when we are actually 'run'-ing.
        logInfo("Adding shutdown hook for context " + sc)
        override def run() {
          logInfo("Invoking sc stop from shutdown hook")
          // Best effort ...
          // Iterating the Java CopyOnWriteArrayList relies on the JavaConversions import.
          for (master <- applicationMasters) {
      } )

    // Wait for initialization to complete and for at least 'some' nodes to be allocated.
    yarnAllocatorLoop.synchronized {
      while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {

  def main(argStrings: Array[String]) {
    val args = new ApplicationMasterArguments(argStrings)
    new ApplicationMaster(args).run()