aboutsummaryrefslogtreecommitdiff
path: root/resource-managers/yarn/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'resource-managers/yarn/src/main')
-rw-r--r--resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider3
-rw-r--r--resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager1
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala791
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala105
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala1541
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala86
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala186
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala266
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala224
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala727
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala135
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala317
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala347
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala235
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala105
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala130
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala74
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala110
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala129
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala57
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala53
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala143
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala157
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala56
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala37
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala67
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala39
-rw-r--r--resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala315
28 files changed, 6436 insertions, 0 deletions
diff --git a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider
new file mode 100644
index 0000000000..22ead56d23
--- /dev/null
+++ b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider
@@ -0,0 +1,3 @@
+org.apache.spark.deploy.yarn.security.HDFSCredentialProvider
+org.apache.spark.deploy.yarn.security.HBaseCredentialProvider
+org.apache.spark.deploy.yarn.security.HiveCredentialProvider
diff --git a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
new file mode 100644
index 0000000000..6e8a1ebfc6
--- /dev/null
+++ b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
@@ -0,0 +1 @@
+org.apache.spark.scheduler.cluster.YarnClusterManager
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000..0378ef4fac
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,791 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, IOException}
+import java.lang.reflect.InvocationTargetException
+import java.net.{Socket, URI, URL}
+import java.util.concurrent.{TimeoutException, TimeUnit}
+
+import scala.collection.mutable.HashMap
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+
+import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.history.HistoryServer
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, ConfigurableCredentialManager}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.rpc._
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend}
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.util._
+
+/**
+ * Common application master functionality for Spark on Yarn.
+ */
+private[spark] class ApplicationMaster(
+ args: ApplicationMasterArguments,
+ client: YarnRMClient)
+ extends Logging {
+
+ // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
+ // optimal as more containers are available. Might need to handle this better.
+
+ private val sparkConf = new SparkConf()
+ private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf)
+ .asInstanceOf[YarnConfiguration]
+ private val isClusterMode = args.userClass != null
+
+ // Default to twice the number of executors (twice the maximum number of executors if dynamic
+ // allocation is enabled), with a minimum of 3.
+
+ private val maxNumExecutorFailures = {
+ val effectiveNumExecutors =
+ if (Utils.isDynamicAllocationEnabled(sparkConf)) {
+ sparkConf.get(DYN_ALLOCATION_MAX_EXECUTORS)
+ } else {
+ sparkConf.get(EXECUTOR_INSTANCES).getOrElse(0)
+ }
+ // By default, effectiveNumExecutors is Int.MaxValue if dynamic allocation is enabled. We need
+ // avoid the integer overflow here.
+ val defaultMaxNumExecutorFailures = math.max(3,
+ if (effectiveNumExecutors > Int.MaxValue / 2) Int.MaxValue else (2 * effectiveNumExecutors))
+
+ sparkConf.get(MAX_EXECUTOR_FAILURES).getOrElse(defaultMaxNumExecutorFailures)
+ }
+
+ @volatile private var exitCode = 0
+ @volatile private var unregistered = false
+ @volatile private var finished = false
+ @volatile private var finalStatus = getDefaultFinalStatus
+ @volatile private var finalMsg: String = ""
+ @volatile private var userClassThread: Thread = _
+
+ @volatile private var reporterThread: Thread = _
+ @volatile private var allocator: YarnAllocator = _
+
+ // Lock for controlling the allocator (heartbeat) thread.
+ private val allocatorLock = new Object()
+
+ // Steady state heartbeat interval. We want to be reasonably responsive without causing too many
+ // requests to RM.
+ private val heartbeatInterval = {
+ // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
+ val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ math.max(0, math.min(expiryInterval / 2, sparkConf.get(RM_HEARTBEAT_INTERVAL)))
+ }
+
+ // Initial wait interval before allocator poll, to allow for quicker ramp up when executors are
+ // being requested.
+ private val initialAllocationInterval = math.min(heartbeatInterval,
+ sparkConf.get(INITIAL_HEARTBEAT_INTERVAL))
+
+ // Next wait interval before allocator poll.
+ private var nextAllocationInterval = initialAllocationInterval
+
+ private var rpcEnv: RpcEnv = null
+ private var amEndpoint: RpcEndpointRef = _
+
+ // In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
+ private val sparkContextPromise = Promise[SparkContext]()
+
+ private var credentialRenewer: AMCredentialRenewer = _
+
+ // Load the list of localized files set by the client. This is used when launching executors,
+ // and is loaded here so that these configs don't pollute the Web UI's environment page in
+ // cluster mode.
+ private val localResources = {
+ logInfo("Preparing Local resources")
+ val resources = HashMap[String, LocalResource]()
+
+ def setupDistributedCache(
+ file: String,
+ rtype: LocalResourceType,
+ timestamp: String,
+ size: String,
+ vis: String): Unit = {
+ val uri = new URI(file)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource])
+ amJarRsrc.setType(rtype)
+ amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis))
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri))
+ amJarRsrc.setTimestamp(timestamp.toLong)
+ amJarRsrc.setSize(size.toLong)
+
+ val fileName = Option(uri.getFragment()).getOrElse(new Path(uri).getName())
+ resources(fileName) = amJarRsrc
+ }
+
+ val distFiles = sparkConf.get(CACHED_FILES)
+ val fileSizes = sparkConf.get(CACHED_FILES_SIZES)
+ val timeStamps = sparkConf.get(CACHED_FILES_TIMESTAMPS)
+ val visibilities = sparkConf.get(CACHED_FILES_VISIBILITIES)
+ val resTypes = sparkConf.get(CACHED_FILES_TYPES)
+
+ for (i <- 0 to distFiles.size - 1) {
+ val resType = LocalResourceType.valueOf(resTypes(i))
+ setupDistributedCache(distFiles(i), resType, timeStamps(i).toString, fileSizes(i).toString,
+ visibilities(i))
+ }
+
+ // Distribute the conf archive to executors.
+ sparkConf.get(CACHED_CONF_ARCHIVE).foreach { path =>
+ val uri = new URI(path)
+ val fs = FileSystem.get(uri, yarnConf)
+ val status = fs.getFileStatus(new Path(uri))
+ // SPARK-16080: Make sure to use the correct name for the destination when distributing the
+ // conf archive to executors.
+ val destUri = new URI(uri.getScheme(), uri.getRawSchemeSpecificPart(),
+ Client.LOCALIZED_CONF_DIR)
+ setupDistributedCache(destUri.toString(), LocalResourceType.ARCHIVE,
+ status.getModificationTime().toString, status.getLen.toString,
+ LocalResourceVisibility.PRIVATE.name())
+ }
+
+ // Clean up the configuration so it doesn't show up in the Web UI (since it's really noisy).
+ CACHE_CONFIGS.foreach { e =>
+ sparkConf.remove(e)
+ sys.props.remove(e.key)
+ }
+
+ resources.toMap
+ }
+
+ def getAttemptId(): ApplicationAttemptId = {
+ client.getAttemptId()
+ }
+
+ final def run(): Int = {
+ try {
+ val appAttemptId = client.getAttemptId()
+
+ var attemptID: Option[String] = None
+
+ if (isClusterMode) {
+ // Set the web ui port to be ephemeral for yarn so we don't conflict with
+ // other spark processes running on the same box
+ System.setProperty("spark.ui.port", "0")
+
+ // Set the master and deploy mode property to match the requested mode.
+ System.setProperty("spark.master", "yarn")
+ System.setProperty("spark.submit.deployMode", "cluster")
+
+ // Set this internal configuration if it is running on cluster mode, this
+ // configuration will be checked in SparkContext to avoid misuse of yarn cluster mode.
+ System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString())
+
+ attemptID = Option(appAttemptId.getAttemptId.toString)
+ }
+
+ new CallerContext(
+ "APPMASTER", sparkConf.get(APP_CALLER_CONTEXT),
+ Option(appAttemptId.getApplicationId.toString), attemptID).setCurrentContext()
+
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+
+ val fs = FileSystem.get(yarnConf)
+
+ // This shutdown hook should run *after* the SparkContext is shut down.
+ val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1
+ ShutdownHookManager.addShutdownHook(priority) { () =>
+ val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf)
+ val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts
+
+ if (!finished) {
+ // The default state of ApplicationMaster is failed if it is invoked by shut down hook.
+ // This behavior is different compared to 1.x version.
+ // If user application is exited ahead of time by calling System.exit(N), here mark
+ // this application as failed with EXIT_EARLY. For a good shutdown, user shouldn't call
+ // System.exit(0) to terminate the application.
+ finish(finalStatus,
+ ApplicationMaster.EXIT_EARLY,
+ "Shutdown hook called before final status was reported.")
+ }
+
+ if (!unregistered) {
+ // we only want to unregister if we don't want the RM to retry
+ if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) {
+ unregister(finalStatus, finalMsg)
+ cleanupStagingDir(fs)
+ }
+ }
+ }
+
+ // Call this to force generation of secret so it gets populated into the
+ // Hadoop UGI. This has to happen before the startUserApplication which does a
+ // doAs in order for the credentials to be passed on to the executor containers.
+ val securityMgr = new SecurityManager(sparkConf)
+
+ // If the credentials file config is present, we must periodically renew tokens. So create
+ // a new AMDelegationTokenRenewer
+ if (sparkConf.contains(CREDENTIALS_FILE_PATH.key)) {
+ // If a principal and keytab have been set, use that to create new credentials for executors
+ // periodically
+ credentialRenewer =
+ new ConfigurableCredentialManager(sparkConf, yarnConf).credentialRenewer()
+ credentialRenewer.scheduleLoginFromKeytab()
+ }
+
+ if (isClusterMode) {
+ runDriver(securityMgr)
+ } else {
+ runExecutorLauncher(securityMgr)
+ }
+ } catch {
+ case e: Exception =>
+ // catch everything else if not specifically handled
+ logError("Uncaught exception: ", e)
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION,
+ "Uncaught exception: " + e)
+ }
+ exitCode
+ }
+
+ /**
+ * Set the default final application status for client mode to UNDEFINED to handle
+ * if YARN HA restarts the application so that it properly retries. Set the final
+ * status to SUCCEEDED in cluster mode to handle if the user calls System.exit
+ * from the application code.
+ */
+ final def getDefaultFinalStatus(): FinalApplicationStatus = {
+ if (isClusterMode) {
+ FinalApplicationStatus.FAILED
+ } else {
+ FinalApplicationStatus.UNDEFINED
+ }
+ }
+
+ /**
+ * unregister is used to completely unregister the application from the ResourceManager.
+ * This means the ResourceManager will not retry the application attempt on your behalf if
+ * a failure occurred.
+ */
+ final def unregister(status: FinalApplicationStatus, diagnostics: String = null): Unit = {
+ synchronized {
+ if (!unregistered) {
+ logInfo(s"Unregistering ApplicationMaster with $status" +
+ Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse(""))
+ unregistered = true
+ client.unregister(status, Option(diagnostics).getOrElse(""))
+ }
+ }
+ }
+
+ final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = {
+ synchronized {
+ if (!finished) {
+ val inShutdown = ShutdownHookManager.inShutdown()
+ logInfo(s"Final app status: $status, exitCode: $code" +
+ Option(msg).map(msg => s", (reason: $msg)").getOrElse(""))
+ exitCode = code
+ finalStatus = status
+ finalMsg = msg
+ finished = true
+ if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) {
+ logDebug("shutting down reporter thread")
+ reporterThread.interrupt()
+ }
+ if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) {
+ logDebug("shutting down user thread")
+ userClassThread.interrupt()
+ }
+ if (!inShutdown && credentialRenewer != null) {
+ credentialRenewer.stop()
+ credentialRenewer = null
+ }
+ }
+ }
+ }
+
+ private def sparkContextInitialized(sc: SparkContext) = {
+ sparkContextPromise.success(sc)
+ }
+
+ private def registerAM(
+ _sparkConf: SparkConf,
+ _rpcEnv: RpcEnv,
+ driverRef: RpcEndpointRef,
+ uiAddress: String,
+ securityMgr: SecurityManager) = {
+ val appId = client.getAttemptId().getApplicationId().toString()
+ val attemptId = client.getAttemptId().getAttemptId().toString()
+ val historyAddress =
+ _sparkConf.get(HISTORY_SERVER_ADDRESS)
+ .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) }
+ .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
+ .getOrElse("")
+
+ val driverUrl = RpcEndpointAddress(
+ _sparkConf.get("spark.driver.host"),
+ _sparkConf.get("spark.driver.port").toInt,
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
+
+ // Before we initialize the allocator, let's log the information about how executors will
+ // be run up front, to avoid printing this out for every single executor being launched.
+ // Use placeholders for information that changes such as executor IDs.
+ logInfo {
+ val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt
+ val executorCores = sparkConf.get(EXECUTOR_CORES)
+ val dummyRunner = new ExecutorRunnable(None, yarnConf, sparkConf, driverUrl, "<executorId>",
+ "<hostname>", executorMemory, executorCores, appId, securityMgr, localResources)
+ dummyRunner.launchContextDebugInfo()
+ }
+
+ allocator = client.register(driverUrl,
+ driverRef,
+ yarnConf,
+ _sparkConf,
+ uiAddress,
+ historyAddress,
+ securityMgr,
+ localResources)
+
+ allocator.allocateResources()
+ reporterThread = launchReporterThread()
+ }
+
+ /**
+ * Create an [[RpcEndpoint]] that communicates with the driver.
+ *
+ * In cluster mode, the AM and the driver belong to same process
+ * so the AMEndpoint need not monitor lifecycle of the driver.
+ *
+ * @return A reference to the driver's RPC endpoint.
+ */
+ private def runAMEndpoint(
+ host: String,
+ port: String,
+ isClusterMode: Boolean): RpcEndpointRef = {
+ val driverEndpoint = rpcEnv.setupEndpointRef(
+ RpcAddress(host, port.toInt),
+ YarnSchedulerBackend.ENDPOINT_NAME)
+ amEndpoint =
+ rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode))
+ driverEndpoint
+ }
+
+ private def runDriver(securityMgr: SecurityManager): Unit = {
+ addAmIpFilter()
+ userClassThread = startUserApplication()
+
+ // This a bit hacky, but we need to wait until the spark.driver.port property has
+ // been set by the Thread executing the user class.
+ logInfo("Waiting for spark context initialization...")
+ val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
+ try {
+ val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
+ Duration(totalWaitTime, TimeUnit.MILLISECONDS))
+ if (sc != null) {
+ rpcEnv = sc.env.rpcEnv
+ val driverRef = runAMEndpoint(
+ sc.getConf.get("spark.driver.host"),
+ sc.getConf.get("spark.driver.port"),
+ isClusterMode = true)
+ registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl).getOrElse(""),
+ securityMgr)
+ } else {
+ // Sanity check; should never happen in normal operation, since sc should only be null
+ // if the user app did not create a SparkContext.
+ if (!finished) {
+ throw new IllegalStateException("SparkContext is null but app is still running!")
+ }
+ }
+ userClassThread.join()
+ } catch {
+ case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
+ logError(
+ s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
+ "Please check earlier log output for errors. Failing the application.")
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_SC_NOT_INITED,
+ "Timed out waiting for SparkContext.")
+ }
+ }
+
+ private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
+ val port = sparkConf.get(AM_PORT)
+ rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr,
+ clientMode = true)
+ val driverRef = waitForSparkDriver()
+ addAmIpFilter()
+ registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""),
+ securityMgr)
+
+ // In client mode the actor will stop the reporter thread.
+ reporterThread.join()
+ }
+
+ private def launchReporterThread(): Thread = {
+ // The number of failures in a row until Reporter thread give up
+ val reporterMaxFailures = sparkConf.get(MAX_REPORTER_THREAD_FAILURES)
+
+ val t = new Thread {
+ override def run() {
+ var failureCount = 0
+ while (!finished) {
+ try {
+ if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES,
+ s"Max number of executor failures ($maxNumExecutorFailures) reached")
+ } else {
+ logDebug("Sending progress")
+ allocator.allocateResources()
+ }
+ failureCount = 0
+ } catch {
+ case i: InterruptedException =>
+ case e: Throwable =>
+ failureCount += 1
+ // this exception was introduced in hadoop 2.4 and this code would not compile
+ // with earlier versions if we refer it directly.
+ if ("org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException" ==
+ e.getClass().getName()) {
+ logError("Exception from Reporter thread.", e)
+ finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE,
+ e.getMessage)
+ } else if (!NonFatal(e) || failureCount >= reporterMaxFailures) {
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " +
+ s"$failureCount time(s) from Reporter thread.")
+ } else {
+ logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e)
+ }
+ }
+ try {
+ val numPendingAllocate = allocator.getPendingAllocate.size
+ var sleepStart = 0L
+ var sleepInterval = 200L // ms
+ allocatorLock.synchronized {
+ sleepInterval =
+ if (numPendingAllocate > 0 || allocator.getNumPendingLossReasonRequests > 0) {
+ val currentAllocationInterval =
+ math.min(heartbeatInterval, nextAllocationInterval)
+ nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow
+ currentAllocationInterval
+ } else {
+ nextAllocationInterval = initialAllocationInterval
+ heartbeatInterval
+ }
+ sleepStart = System.currentTimeMillis()
+ allocatorLock.wait(sleepInterval)
+ }
+ val sleepDuration = System.currentTimeMillis() - sleepStart
+ if (sleepDuration < sleepInterval) {
+ // log when sleep is interrupted
+ logDebug(s"Number of pending allocations is $numPendingAllocate. " +
+ s"Slept for $sleepDuration/$sleepInterval ms.")
+ // if sleep was less than the minimum interval, sleep for the rest of it
+ val toSleep = math.max(0, initialAllocationInterval - sleepDuration)
+ if (toSleep > 0) {
+ logDebug(s"Going back to sleep for $toSleep ms")
+ // use Thread.sleep instead of allocatorLock.wait. there is no need to be woken up
+ // by the methods that signal allocatorLock because this is just finishing the min
+ // sleep interval, which should happen even if this is signalled again.
+ Thread.sleep(toSleep)
+ }
+ } else {
+ logDebug(s"Number of pending allocations is $numPendingAllocate. " +
+ s"Slept for $sleepDuration/$sleepInterval.")
+ }
+ } catch {
+ case e: InterruptedException =>
+ }
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.setName("Reporter")
+ t.start()
+ logInfo(s"Started progress reporter thread with (heartbeat : $heartbeatInterval, " +
+ s"initial allocation : $initialAllocationInterval) intervals")
+ t
+ }
+
+ /**
+ * Clean up the staging directory.
+ */
+ private def cleanupStagingDir(fs: FileSystem) {
+ var stagingDirPath: Path = null
+ try {
+ val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES)
+ if (!preserveFiles) {
+ stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR"))
+ if (stagingDirPath == null) {
+ logError("Staging directory is null")
+ return
+ }
+ logInfo("Deleting staging directory " + stagingDirPath)
+ fs.delete(stagingDirPath, true)
+ }
+ } catch {
+ case ioe: IOException =>
+ logError("Failed to cleanup staging dir " + stagingDirPath, ioe)
+ }
+ }
+
+ private def waitForSparkDriver(): RpcEndpointRef = {
+ logInfo("Waiting for Spark driver to be reachable.")
+ var driverUp = false
+ val hostport = args.userArgs(0)
+ val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+
+ // Spark driver should already be up since it launched us, but we don't want to
+ // wait forever, so wait 100 seconds max to match the cluster mode setting.
+ val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME)
+ val deadline = System.currentTimeMillis + totalWaitTimeMs
+
+ while (!driverUp && !finished && System.currentTimeMillis < deadline) {
+ try {
+ val socket = new Socket(driverHost, driverPort)
+ socket.close()
+ logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at %s:%s, retrying ...".
+ format(driverHost, driverPort))
+ Thread.sleep(100L)
+ }
+ }
+
+ if (!driverUp) {
+ throw new SparkException("Failed to connect to driver!")
+ }
+
+ sparkConf.set("spark.driver.host", driverHost)
+ sparkConf.set("spark.driver.port", driverPort.toString)
+
+ runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false)
+ }
+
+ /** Add the Yarn IP filter that is required for properly securing the UI. */
+ private def addAmIpFilter() = {
+ val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
+ val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
+ val params = client.getAmIpFilterParams(yarnConf, proxyBase)
+ if (isClusterMode) {
+ System.setProperty("spark.ui.filters", amFilter)
+ params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) }
+ } else {
+ amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase))
+ }
+ }
+
+ /**
+ * Start the user class, which contains the spark driver, in a separate Thread.
+ * If the main routine exits cleanly or exits with System.exit(N) for any N
+ * we assume it was successful, for all other cases we assume failure.
+ *
+ * Returns the user thread that was started.
+ */
+ private def startUserApplication(): Thread = {
+ logInfo("Starting the user application in a separate Thread")
+
+ val classpath = Client.getUserClasspath(sparkConf)
+ val urls = classpath.map { entry =>
+ new URL("file:" + new File(entry.getPath()).getAbsolutePath())
+ }
+ val userClassLoader =
+ if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) {
+ new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
+ } else {
+ new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
+ }
+
+ var userArgs = args.userArgs
+ if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
+ // When running pyspark, the app is run using PythonRunner. The second argument is the list
+ // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty.
+ userArgs = Seq(args.primaryPyFile, "") ++ userArgs
+ }
+ if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
+ // TODO(davies): add R dependencies here
+ }
+ val mainMethod = userClassLoader.loadClass(args.userClass)
+ .getMethod("main", classOf[Array[String]])
+
+ val userThread = new Thread {
+ override def run() {
+ try {
+ mainMethod.invoke(null, userArgs.toArray)
+ finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+ logDebug("Done running users class")
+ } catch {
+ case e: InvocationTargetException =>
+ e.getCause match {
+ case _: InterruptedException =>
+ // Reporter thread can interrupt to stop user class
+ case SparkUserAppException(exitCode) =>
+ val msg = s"User application exited with status $exitCode"
+ logError(msg)
+ finish(FinalApplicationStatus.FAILED, exitCode, msg)
+ case cause: Throwable =>
+ logError("User class threw exception: " + cause, cause)
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
+ "User class threw exception: " + cause)
+ }
+ sparkContextPromise.tryFailure(e.getCause())
+ } finally {
+ // Notify the thread waiting for the SparkContext, in case the application did not
+ // instantiate one. This will do nothing when the user code instantiates a SparkContext
+ // (with the correct master), or when the user code throws an exception (due to the
+ // tryFailure above).
+ sparkContextPromise.trySuccess(null)
+ }
+ }
+ }
+ userThread.setContextClassLoader(userClassLoader)
+ userThread.setName("Driver")
+ userThread.start()
+ userThread
+ }
+
+ private def resetAllocatorInterval(): Unit = allocatorLock.synchronized {
+ nextAllocationInterval = initialAllocationInterval
+ allocatorLock.notifyAll()
+ }
+
+ /**
+ * An [[RpcEndpoint]] that communicates with the driver's scheduler backend.
+ */
+ private class AMEndpoint(
+ override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean)
+ extends RpcEndpoint with Logging {
+
+ override def onStart(): Unit = {
+ driver.send(RegisterClusterManager(self))
+ }
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case x: AddWebUIFilter =>
+ logInfo(s"Add WebUI Filter. $x")
+ driver.send(x)
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) =>
+ Option(allocator) match {
+ case Some(a) =>
+ if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal,
+ localityAwareTasks, hostToLocalTaskCount)) {
+ resetAllocatorInterval()
+ }
+ context.reply(true)
+
+ case None =>
+ logWarning("Container allocator is not ready to request executors yet.")
+ context.reply(false)
+ }
+
+ case KillExecutors(executorIds) =>
+ logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.")
+ Option(allocator) match {
+ case Some(a) => executorIds.foreach(a.killExecutor)
+ case None => logWarning("Container allocator is not ready to kill executors yet.")
+ }
+ context.reply(true)
+
+ case GetExecutorLossReason(eid) =>
+ Option(allocator) match {
+ case Some(a) =>
+ a.enqueueGetLossReasonRequest(eid, context)
+ resetAllocatorInterval()
+ case None =>
+ logWarning("Container allocator is not ready to find executor loss reasons yet.")
+ }
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ // In cluster mode, do not rely on the disassociated event to exit
+ // This avoids potentially reporting incorrect exit codes if the driver fails
+ if (!isClusterMode) {
+ logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress")
+ finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+ }
+ }
+ }
+
+}
+
+object ApplicationMaster extends Logging {
+
+ // exit codes for different causes, no reason behind the values
+ private val EXIT_SUCCESS = 0
+ private val EXIT_UNCAUGHT_EXCEPTION = 10
+ private val EXIT_MAX_EXECUTOR_FAILURES = 11
+ private val EXIT_REPORTER_FAILURE = 12
+ private val EXIT_SC_NOT_INITED = 13
+ private val EXIT_SECURITY = 14
+ private val EXIT_EXCEPTION_USER_CLASS = 15
+ private val EXIT_EARLY = 16
+
+ private var master: ApplicationMaster = _
+
+ def main(args: Array[String]): Unit = {
+ SignalUtils.registerLogger(log)
+ val amArgs = new ApplicationMasterArguments(args)
+
+ // Load the properties file with the Spark configuration and set entries as system properties,
+ // so that user code run inside the AM also has access to them.
+ // Note: we must do this before SparkHadoopUtil instantiated
+ if (amArgs.propertiesFile != null) {
+ Utils.getPropertiesFromFile(amArgs.propertiesFile).foreach { case (k, v) =>
+ sys.props(k) = v
+ }
+ }
+ SparkHadoopUtil.get.runAsSparkUser { () =>
+ master = new ApplicationMaster(amArgs, new YarnRMClient)
+ System.exit(master.run())
+ }
+ }
+
+ private[spark] def sparkContextInitialized(sc: SparkContext): Unit = {
+ master.sparkContextInitialized(sc)
+ }
+
+ private[spark] def getAttemptId(): ApplicationAttemptId = {
+ master.getAttemptId
+ }
+
+}
+
+/**
+ * This object does not provide any special functionality. It exists so that it's easy to tell
+ * apart the client-mode AM from the cluster-mode AM when using tools such as ps or jps.
+ */
+object ExecutorLauncher {
+
+ def main(args: Array[String]): Unit = {
+ ApplicationMaster.main(args)
+ }
+
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000..5cdec87667
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.util.{IntParam, MemoryParam}
+
+class ApplicationMasterArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var primaryPyFile: String = null
+ var primaryRFile: String = null
+ var userArgs: Seq[String] = Nil
+ var propertiesFile: String = null
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+
+ var args = inputArgs
+
+ while (!args.isEmpty) {
+ // --num-workers, --worker-memory, and --worker-cores are deprecated since 1.0,
+ // the properties with executor in their names are preferred.
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--primary-py-file") :: value :: tail =>
+ primaryPyFile = value
+ args = tail
+
+ case ("--primary-r-file") :: value :: tail =>
+ primaryRFile = value
+ args = tail
+
+ case ("--arg") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ args = tail
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ if (primaryPyFile != null && primaryRFile != null) {
+ // scalastyle:off println
+ System.err.println("Cannot have primary-py-file and primary-r-file at the same time")
+ // scalastyle:on println
+ System.exit(-1)
+ }
+
+ userArgs = userArgsBuffer.toList
+ }
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ // scalastyle:off println
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println("""
+ |Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options]
+ |Options:
+ | --jar JAR_PATH Path to your application's JAR file
+ | --class CLASS_NAME Name of your application's main class
+ | --primary-py-file A main Python file
+ | --primary-r-file A main R file
+ | --arg ARG Argument to be passed to your application's main class.
+ | Multiple invocations are possible, each will be passed in order.
+ | --properties-file FILE Path to a custom Spark properties file.
+ """.stripMargin)
+ // scalastyle:on println
+ System.exit(exitCode)
+ }
+}
+
+object ApplicationMasterArguments {
+ val DEFAULT_NUMBER_EXECUTORS = 2
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
new file mode 100644
index 0000000000..be419cee77
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -0,0 +1,1541 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, FileOutputStream, IOException, OutputStreamWriter}
+import java.net.{InetAddress, UnknownHostException, URI}
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+import java.util.{Properties, UUID}
+import java.util.zip.{ZipEntry, ZipOutputStream}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
+import scala.util.{Failure, Success, Try}
+import scala.util.control.NonFatal
+
+import com.google.common.base.Objects
+import com.google.common.io.Files
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs._
+import org.apache.hadoop.fs.permission.FsPermission
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.mapreduce.MRJobConfig
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import org.apache.hadoop.util.StringUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication}
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException
+import org.apache.hadoop.yarn.util.Records
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils}
+import org.apache.spark.util.{CallerContext, Utils}
+
+private[spark] class Client(
+ val args: ClientArguments,
+ val hadoopConf: Configuration,
+ val sparkConf: SparkConf)
+ extends Logging {
+
+ import Client._
+ import YarnSparkHadoopUtil._
+
+ def this(clientArgs: ClientArguments, spConf: SparkConf) =
+ this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf)
+
+ private val yarnClient = YarnClient.createYarnClient
+ private val yarnConf = new YarnConfiguration(hadoopConf)
+
+ private val isClusterMode = sparkConf.get("spark.submit.deployMode", "client") == "cluster"
+
+ // AM related configurations
+ private val amMemory = if (isClusterMode) {
+ sparkConf.get(DRIVER_MEMORY).toInt
+ } else {
+ sparkConf.get(AM_MEMORY).toInt
+ }
+ private val amMemoryOverhead = {
+ val amMemoryOverheadEntry = if (isClusterMode) DRIVER_MEMORY_OVERHEAD else AM_MEMORY_OVERHEAD
+ sparkConf.get(amMemoryOverheadEntry).getOrElse(
+ math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt
+ }
+ private val amCores = if (isClusterMode) {
+ sparkConf.get(DRIVER_CORES)
+ } else {
+ sparkConf.get(AM_CORES)
+ }
+
+ // Executor related configurations
+ private val executorMemory = sparkConf.get(EXECUTOR_MEMORY)
+ private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
+ math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt
+
+ private val distCacheMgr = new ClientDistributedCacheManager()
+
+ private var loginFromKeytab = false
+ private var principal: String = null
+ private var keytab: String = null
+ private var credentials: Credentials = null
+
+ private val launcherBackend = new LauncherBackend() {
+ override def onStopRequest(): Unit = {
+ if (isClusterMode && appId != null) {
+ yarnClient.killApplication(appId)
+ } else {
+ setState(SparkAppHandle.State.KILLED)
+ stop()
+ }
+ }
+ }
+ private val fireAndForget = isClusterMode && !sparkConf.get(WAIT_FOR_APP_COMPLETION)
+
+ private var appId: ApplicationId = null
+
+ // The app staging dir based on the STAGING_DIR configuration if configured
+ // otherwise based on the users home directory.
+ private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) }
+ .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory())
+
+ private val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)
+
+ def reportLauncherState(state: SparkAppHandle.State): Unit = {
+ launcherBackend.setState(state)
+ }
+
+ def stop(): Unit = {
+ launcherBackend.close()
+ yarnClient.stop()
+ // Unset YARN mode system env variable, to allow switching between cluster types.
+ System.clearProperty("SPARK_YARN_MODE")
+ }
+
+ /**
+ * Submit an application running our ApplicationMaster to the ResourceManager.
+ *
+ * The stable Yarn API provides a convenience method (YarnClient#createApplication) for
+ * creating applications and setting up the application submission context. This was not
+ * available in the alpha API.
+ */
+ def submitApplication(): ApplicationId = {
+ var appId: ApplicationId = null
+ try {
+ launcherBackend.connect()
+ // Setup the credentials before doing anything else,
+ // so we have don't have issues at any point.
+ setupCredentials()
+ yarnClient.init(yarnConf)
+ yarnClient.start()
+
+ logInfo("Requesting a new application from cluster with %d NodeManagers"
+ .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers))
+
+ // Get a new application from our RM
+ val newApp = yarnClient.createApplication()
+ val newAppResponse = newApp.getNewApplicationResponse()
+ appId = newAppResponse.getApplicationId()
+ reportLauncherState(SparkAppHandle.State.SUBMITTED)
+ launcherBackend.setAppId(appId.toString)
+
+ new CallerContext("CLIENT", sparkConf.get(APP_CALLER_CONTEXT),
+ Option(appId.toString)).setCurrentContext()
+
+ // Verify whether the cluster has enough resources for our AM
+ verifyClusterResources(newAppResponse)
+
+ // Set up the appropriate contexts to launch our AM
+ val containerContext = createContainerLaunchContext(newAppResponse)
+ val appContext = createApplicationSubmissionContext(newApp, containerContext)
+
+ // Finally, submit and monitor the application
+ logInfo(s"Submitting application $appId to ResourceManager")
+ yarnClient.submitApplication(appContext)
+ appId
+ } catch {
+ case e: Throwable =>
+ if (appId != null) {
+ cleanupStagingDir(appId)
+ }
+ throw e
+ }
+ }
+
+ /**
+ * Cleanup application staging directory.
+ */
+ private def cleanupStagingDir(appId: ApplicationId): Unit = {
+ val stagingDirPath = new Path(appStagingBaseDir, getAppStagingDir(appId))
+ try {
+ val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES)
+ val fs = stagingDirPath.getFileSystem(hadoopConf)
+ if (!preserveFiles && fs.delete(stagingDirPath, true)) {
+ logInfo(s"Deleted staging directory $stagingDirPath")
+ }
+ } catch {
+ case ioe: IOException =>
+ logWarning("Failed to cleanup staging dir " + stagingDirPath, ioe)
+ }
+ }
+
+ /**
+ * Set up the context for submitting our ApplicationMaster.
+ * This uses the YarnClientApplication not available in the Yarn alpha API.
+ */
+ def createApplicationSubmissionContext(
+ newApp: YarnClientApplication,
+ containerContext: ContainerLaunchContext): ApplicationSubmissionContext = {
+ val appContext = newApp.getApplicationSubmissionContext
+ appContext.setApplicationName(sparkConf.get("spark.app.name", "Spark"))
+ appContext.setQueue(sparkConf.get(QUEUE_NAME))
+ appContext.setAMContainerSpec(containerContext)
+ appContext.setApplicationType("SPARK")
+
+ sparkConf.get(APPLICATION_TAGS).foreach { tags =>
+ try {
+ // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use
+ // reflection to set it, printing a warning if a tag was specified but the YARN version
+ // doesn't support it.
+ val method = appContext.getClass().getMethod(
+ "setApplicationTags", classOf[java.util.Set[String]])
+ method.invoke(appContext, new java.util.HashSet[String](tags.asJava))
+ } catch {
+ case e: NoSuchMethodException =>
+ logWarning(s"Ignoring ${APPLICATION_TAGS.key} because this version of " +
+ "YARN does not support it")
+ }
+ }
+ sparkConf.get(MAX_APP_ATTEMPTS) match {
+ case Some(v) => appContext.setMaxAppAttempts(v)
+ case None => logDebug(s"${MAX_APP_ATTEMPTS.key} is not set. " +
+ "Cluster's default value will be used.")
+ }
+
+ sparkConf.get(AM_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).foreach { interval =>
+ try {
+ val method = appContext.getClass().getMethod(
+ "setAttemptFailuresValidityInterval", classOf[Long])
+ method.invoke(appContext, interval: java.lang.Long)
+ } catch {
+ case e: NoSuchMethodException =>
+ logWarning(s"Ignoring ${AM_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS.key} because " +
+ "the version of YARN does not support it")
+ }
+ }
+
+ val capability = Records.newRecord(classOf[Resource])
+ capability.setMemory(amMemory + amMemoryOverhead)
+ capability.setVirtualCores(amCores)
+
+ sparkConf.get(AM_NODE_LABEL_EXPRESSION) match {
+ case Some(expr) =>
+ try {
+ val amRequest = Records.newRecord(classOf[ResourceRequest])
+ amRequest.setResourceName(ResourceRequest.ANY)
+ amRequest.setPriority(Priority.newInstance(0))
+ amRequest.setCapability(capability)
+ amRequest.setNumContainers(1)
+ val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String])
+ method.invoke(amRequest, expr)
+
+ val setResourceRequestMethod =
+ appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest])
+ setResourceRequestMethod.invoke(appContext, amRequest)
+ } catch {
+ case e: NoSuchMethodException =>
+ logWarning(s"Ignoring ${AM_NODE_LABEL_EXPRESSION.key} because the version " +
+ "of YARN does not support it")
+ appContext.setResource(capability)
+ }
+ case None =>
+ appContext.setResource(capability)
+ }
+
+ sparkConf.get(ROLLED_LOG_INCLUDE_PATTERN).foreach { includePattern =>
+ try {
+ val logAggregationContext = Records.newRecord(
+ Utils.classForName("org.apache.hadoop.yarn.api.records.LogAggregationContext"))
+ .asInstanceOf[Object]
+
+ val setRolledLogsIncludePatternMethod =
+ logAggregationContext.getClass.getMethod("setRolledLogsIncludePattern", classOf[String])
+ setRolledLogsIncludePatternMethod.invoke(logAggregationContext, includePattern)
+
+ sparkConf.get(ROLLED_LOG_EXCLUDE_PATTERN).foreach { excludePattern =>
+ val setRolledLogsExcludePatternMethod =
+ logAggregationContext.getClass.getMethod("setRolledLogsExcludePattern", classOf[String])
+ setRolledLogsExcludePatternMethod.invoke(logAggregationContext, excludePattern)
+ }
+
+ val setLogAggregationContextMethod =
+ appContext.getClass.getMethod("setLogAggregationContext",
+ Utils.classForName("org.apache.hadoop.yarn.api.records.LogAggregationContext"))
+ setLogAggregationContextMethod.invoke(appContext, logAggregationContext)
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Ignoring ${ROLLED_LOG_INCLUDE_PATTERN.key} because the version of YARN " +
+ s"does not support it", e)
+ }
+ }
+
+ appContext
+ }
+
+ /** Set up security tokens for launching our ApplicationMaster container. */
+ private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = {
+ val dob = new DataOutputBuffer
+ credentials.writeTokenStorageToStream(dob)
+ amContainer.setTokens(ByteBuffer.wrap(dob.getData))
+ }
+
+ /** Get the application report from the ResourceManager for an application we have submitted. */
+ def getApplicationReport(appId: ApplicationId): ApplicationReport =
+ yarnClient.getApplicationReport(appId)
+
+ /**
+ * Return the security token used by this client to communicate with the ApplicationMaster.
+ * If no security is enabled, the token returned by the report is null.
+ */
+ private def getClientToken(report: ApplicationReport): String =
+ Option(report.getClientToAMToken).map(_.toString).getOrElse("")
+
+ /**
+ * Fail fast if we have requested more resources per container than is available in the cluster.
+ */
+ private def verifyClusterResources(newAppResponse: GetNewApplicationResponse): Unit = {
+ val maxMem = newAppResponse.getMaximumResourceCapability().getMemory()
+ logInfo("Verifying our application has not requested more than the maximum " +
+ s"memory capability of the cluster ($maxMem MB per container)")
+ val executorMem = executorMemory + executorMemoryOverhead
+ if (executorMem > maxMem) {
+ throw new IllegalArgumentException(s"Required executor memory ($executorMemory" +
+ s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " +
+ "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " +
+ "'yarn.nodemanager.resource.memory-mb'.")
+ }
+ val amMem = amMemory + amMemoryOverhead
+ if (amMem > maxMem) {
+ throw new IllegalArgumentException(s"Required AM memory ($amMemory" +
+ s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " +
+ "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.")
+ }
+ logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format(
+ amMem,
+ amMemoryOverhead))
+
+ // We could add checks to make sure the entire cluster has enough resources but that involves
+ // getting all the node reports and computing ourselves.
+ }
+
+ /**
+ * Copy the given file to a remote file system (e.g. HDFS) if needed.
+ * The file is only copied if the source and destination file systems are different. This is used
+ * for preparing resources for launching the ApplicationMaster container. Exposed for testing.
+ */
+ private[yarn] def copyFileToRemote(
+ destDir: Path,
+ srcPath: Path,
+ replication: Short,
+ force: Boolean = false,
+ destName: Option[String] = None): Path = {
+ val destFs = destDir.getFileSystem(hadoopConf)
+ val srcFs = srcPath.getFileSystem(hadoopConf)
+ var destPath = srcPath
+ if (force || !compareFs(srcFs, destFs)) {
+ destPath = new Path(destDir, destName.getOrElse(srcPath.getName()))
+ logInfo(s"Uploading resource $srcPath -> $destPath")
+ FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf)
+ destFs.setReplication(destPath, replication)
+ destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION))
+ } else {
+ logInfo(s"Source and destination file systems are the same. Not copying $srcPath")
+ }
+ // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific
+ // version shows the specific version in the distributed cache configuration
+ val qualifiedDestPath = destFs.makeQualified(destPath)
+ val fc = FileContext.getFileContext(qualifiedDestPath.toUri(), hadoopConf)
+ fc.resolvePath(qualifiedDestPath)
+ }
+
+ /**
+ * Upload any resources to the distributed cache if needed. If a resource is intended to be
+ * consumed locally, set up the appropriate config for downstream code to handle it properly.
+ * This is used for setting up a container launch context for our ApplicationMaster.
+ * Exposed for testing.
+ */
+ def prepareLocalResources(
+ destDir: Path,
+ pySparkArchives: Seq[String]): HashMap[String, LocalResource] = {
+ logInfo("Preparing resources for our AM container")
+ // Upload Spark and the application JAR to the remote file system if necessary,
+ // and add them as local resources to the application master.
+ val fs = destDir.getFileSystem(hadoopConf)
+
+ // Merge credentials obtained from registered providers
+ val nearestTimeOfNextRenewal = credentialManager.obtainCredentials(hadoopConf, credentials)
+
+ if (credentials != null) {
+ logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n"))
+ }
+
+ // If we use principal and keytab to login, also credentials can be renewed some time
+ // after current time, we should pass the next renewal and updating time to credential
+ // renewer and updater.
+ if (loginFromKeytab && nearestTimeOfNextRenewal > System.currentTimeMillis() &&
+ nearestTimeOfNextRenewal != Long.MaxValue) {
+
+ // Valid renewal time is 75% of next renewal time, and the valid update time will be
+ // slightly later then renewal time (80% of next renewal time). This is to make sure
+ // credentials are renewed and updated before expired.
+ val currTime = System.currentTimeMillis()
+ val renewalTime = (nearestTimeOfNextRenewal - currTime) * 0.75 + currTime
+ val updateTime = (nearestTimeOfNextRenewal - currTime) * 0.8 + currTime
+
+ sparkConf.set(CREDENTIALS_RENEWAL_TIME, renewalTime.toLong)
+ sparkConf.set(CREDENTIALS_UPDATE_TIME, updateTime.toLong)
+ }
+
+ // Used to keep track of URIs added to the distributed cache. If the same URI is added
+ // multiple times, YARN will fail to launch containers for the app with an internal
+ // error.
+ val distributedUris = new HashSet[String]
+ // Used to keep track of URIs(files) added to the distribute cache have the same name. If
+ // same name but different path files are added multiple time, YARN will fail to launch
+ // containers for the app with an internal error.
+ val distributedNames = new HashSet[String]
+
+ val replication = sparkConf.get(STAGING_FILE_REPLICATION).map(_.toShort)
+ .getOrElse(fs.getDefaultReplication(destDir))
+ val localResources = HashMap[String, LocalResource]()
+ FileSystem.mkdirs(fs, destDir, new FsPermission(STAGING_DIR_PERMISSION))
+
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+
+ def addDistributedUri(uri: URI): Boolean = {
+ val uriStr = uri.toString()
+ val fileName = new File(uri.getPath).getName
+ if (distributedUris.contains(uriStr)) {
+ logWarning(s"Same path resource $uri added multiple times to distributed cache.")
+ false
+ } else if (distributedNames.contains(fileName)) {
+ logWarning(s"Same name resource $uri added multiple times to distributed cache")
+ false
+ } else {
+ distributedUris += uriStr
+ distributedNames += fileName
+ true
+ }
+ }
+
+ /**
+ * Distribute a file to the cluster.
+ *
+ * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied
+ * to HDFS (if not already there) and added to the application's distributed cache.
+ *
+ * @param path URI of the file to distribute.
+ * @param resType Type of resource being distributed.
+ * @param destName Name of the file in the distributed cache.
+ * @param targetDir Subdirectory where to place the file.
+ * @param appMasterOnly Whether to distribute only to the AM.
+ * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the
+ * localized path for non-local paths, or the input `path` for local paths.
+ * The localized path will be null if the URI has already been added to the cache.
+ */
+ def distribute(
+ path: String,
+ resType: LocalResourceType = LocalResourceType.FILE,
+ destName: Option[String] = None,
+ targetDir: Option[String] = None,
+ appMasterOnly: Boolean = false): (Boolean, String) = {
+ val trimmedPath = path.trim()
+ val localURI = Utils.resolveURI(trimmedPath)
+ if (localURI.getScheme != LOCAL_SCHEME) {
+ if (addDistributedUri(localURI)) {
+ val localPath = getQualifiedLocalPath(localURI, hadoopConf)
+ val linkname = targetDir.map(_ + "/").getOrElse("") +
+ destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName())
+ val destPath = copyFileToRemote(destDir, localPath, replication)
+ val destFs = FileSystem.get(destPath.toUri(), hadoopConf)
+ distCacheMgr.addResource(
+ destFs, hadoopConf, destPath, localResources, resType, linkname, statCache,
+ appMasterOnly = appMasterOnly)
+ (false, linkname)
+ } else {
+ (false, null)
+ }
+ } else {
+ (true, trimmedPath)
+ }
+ }
+
+ // If we passed in a keytab, make sure we copy the keytab to the staging directory on
+ // HDFS, and setup the relevant environment vars, so the AM can login again.
+ if (loginFromKeytab) {
+ logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" +
+ " via the YARN Secure Distributed Cache.")
+ val (_, localizedPath) = distribute(keytab,
+ destName = sparkConf.get(KEYTAB),
+ appMasterOnly = true)
+ require(localizedPath != null, "Keytab file already distributed.")
+ }
+
+ /**
+ * Add Spark to the cache. There are two settings that control what files to add to the cache:
+ * - if a Spark archive is defined, use the archive. The archive is expected to contain
+ * jar files at its root directory.
+ * - if a list of jars is provided, filter the non-local ones, resolve globs, and
+ * add the found files to the cache.
+ *
+ * Note that the archive cannot be a "local" URI. If none of the above settings are found,
+ * then upload all files found in $SPARK_HOME/jars.
+ */
+ val sparkArchive = sparkConf.get(SPARK_ARCHIVE)
+ if (sparkArchive.isDefined) {
+ val archive = sparkArchive.get
+ require(!isLocalUri(archive), s"${SPARK_ARCHIVE.key} cannot be a local URI.")
+ distribute(Utils.resolveURI(archive).toString,
+ resType = LocalResourceType.ARCHIVE,
+ destName = Some(LOCALIZED_LIB_DIR))
+ } else {
+ sparkConf.get(SPARK_JARS) match {
+ case Some(jars) =>
+ // Break the list of jars to upload, and resolve globs.
+ val localJars = new ArrayBuffer[String]()
+ jars.foreach { jar =>
+ if (!isLocalUri(jar)) {
+ val path = getQualifiedLocalPath(Utils.resolveURI(jar), hadoopConf)
+ val pathFs = FileSystem.get(path.toUri(), hadoopConf)
+ pathFs.globStatus(path).filter(_.isFile()).foreach { entry =>
+ distribute(entry.getPath().toUri().toString(),
+ targetDir = Some(LOCALIZED_LIB_DIR))
+ }
+ } else {
+ localJars += jar
+ }
+ }
+
+ // Propagate the local URIs to the containers using the configuration.
+ sparkConf.set(SPARK_JARS, localJars)
+
+ case None =>
+ // No configuration, so fall back to uploading local jar files.
+ logWarning(s"Neither ${SPARK_JARS.key} nor ${SPARK_ARCHIVE.key} is set, falling back " +
+ "to uploading libraries under SPARK_HOME.")
+ val jarsDir = new File(YarnCommandBuilderUtils.findJarsDir(
+ sparkConf.getenv("SPARK_HOME")))
+ val jarsArchive = File.createTempFile(LOCALIZED_LIB_DIR, ".zip",
+ new File(Utils.getLocalDir(sparkConf)))
+ val jarsStream = new ZipOutputStream(new FileOutputStream(jarsArchive))
+
+ try {
+ jarsStream.setLevel(0)
+ jarsDir.listFiles().foreach { f =>
+ if (f.isFile && f.getName.toLowerCase().endsWith(".jar") && f.canRead) {
+ jarsStream.putNextEntry(new ZipEntry(f.getName))
+ Files.copy(f, jarsStream)
+ jarsStream.closeEntry()
+ }
+ }
+ } finally {
+ jarsStream.close()
+ }
+
+ distribute(jarsArchive.toURI.getPath,
+ resType = LocalResourceType.ARCHIVE,
+ destName = Some(LOCALIZED_LIB_DIR))
+ }
+ }
+
+ /**
+ * Copy user jar to the distributed cache if their scheme is not "local".
+ * Otherwise, set the corresponding key in our SparkConf to handle it downstream.
+ */
+ Option(args.userJar).filter(_.trim.nonEmpty).foreach { jar =>
+ val (isLocal, localizedPath) = distribute(jar, destName = Some(APP_JAR_NAME))
+ if (isLocal) {
+ require(localizedPath != null, s"Path $jar already distributed")
+ // If the resource is intended for local use only, handle this downstream
+ // by setting the appropriate property
+ sparkConf.set(APP_JAR, localizedPath)
+ }
+ }
+
+ /**
+ * Do the same for any additional resources passed in through ClientArguments.
+ * Each resource category is represented by a 3-tuple of:
+ * (1) comma separated list of resources in this category,
+ * (2) resource type, and
+ * (3) whether to add these resources to the classpath
+ */
+ val cachedSecondaryJarLinks = ListBuffer.empty[String]
+ List(
+ (sparkConf.get(JARS_TO_DISTRIBUTE), LocalResourceType.FILE, true),
+ (sparkConf.get(FILES_TO_DISTRIBUTE), LocalResourceType.FILE, false),
+ (sparkConf.get(ARCHIVES_TO_DISTRIBUTE), LocalResourceType.ARCHIVE, false)
+ ).foreach { case (flist, resType, addToClasspath) =>
+ flist.foreach { file =>
+ val (_, localizedPath) = distribute(file, resType = resType)
+ // If addToClassPath, we ignore adding jar multiple times to distitrbuted cache.
+ if (addToClasspath) {
+ if (localizedPath != null) {
+ cachedSecondaryJarLinks += localizedPath
+ }
+ } else {
+ if (localizedPath == null) {
+ throw new IllegalArgumentException(s"Attempt to add ($file) multiple times" +
+ " to the distributed cache.")
+ }
+ }
+ }
+ }
+ if (cachedSecondaryJarLinks.nonEmpty) {
+ sparkConf.set(SECONDARY_JARS, cachedSecondaryJarLinks)
+ }
+
+ if (isClusterMode && args.primaryPyFile != null) {
+ distribute(args.primaryPyFile, appMasterOnly = true)
+ }
+
+ pySparkArchives.foreach { f => distribute(f) }
+
+ // The python files list needs to be treated especially. All files that are not an
+ // archive need to be placed in a subdirectory that will be added to PYTHONPATH.
+ sparkConf.get(PY_FILES).foreach { f =>
+ val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None
+ distribute(f, targetDir = targetDir)
+ }
+
+ // Update the configuration with all the distributed files, minus the conf archive. The
+ // conf archive will be handled by the AM differently so that we avoid having to send
+ // this configuration by other means. See SPARK-14602 for one reason of why this is needed.
+ distCacheMgr.updateConfiguration(sparkConf)
+
+ // Upload the conf archive to HDFS manually, and record its location in the configuration.
+ // This will allow the AM to know where the conf archive is in HDFS, so that it can be
+ // distributed to the containers.
+ //
+ // This code forces the archive to be copied, so that unit tests pass (since in that case both
+ // file systems are the same and the archive wouldn't normally be copied). In most (all?)
+ // deployments, the archive would be copied anyway, since it's a temp file in the local file
+ // system.
+ val remoteConfArchivePath = new Path(destDir, LOCALIZED_CONF_ARCHIVE)
+ val remoteFs = FileSystem.get(remoteConfArchivePath.toUri(), hadoopConf)
+ sparkConf.set(CACHED_CONF_ARCHIVE, remoteConfArchivePath.toString())
+
+ val localConfArchive = new Path(createConfArchive().toURI())
+ copyFileToRemote(destDir, localConfArchive, replication, force = true,
+ destName = Some(LOCALIZED_CONF_ARCHIVE))
+
+ // Manually add the config archive to the cache manager so that the AM is launched with
+ // the proper files set up.
+ distCacheMgr.addResource(
+ remoteFs, hadoopConf, remoteConfArchivePath, localResources, LocalResourceType.ARCHIVE,
+ LOCALIZED_CONF_DIR, statCache, appMasterOnly = false)
+
+ // Clear the cache-related entries from the configuration to avoid them polluting the
+ // UI's environment page. This works for client mode; for cluster mode, this is handled
+ // by the AM.
+ CACHE_CONFIGS.foreach(sparkConf.remove)
+
+ localResources
+ }
+
+ /**
+ * Create an archive with the config files for distribution.
+ *
+ * These will be used by AM and executors. The files are zipped and added to the job as an
+ * archive, so that YARN will explode it when distributing to AM and executors. This directory
+ * is then added to the classpath of AM and executor process, just to make sure that everybody
+ * is using the same default config.
+ *
+ * This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR
+ * shows up in the classpath before YARN_CONF_DIR.
+ *
+ * Currently this makes a shallow copy of the conf directory. If there are cases where a
+ * Hadoop config directory contains subdirectories, this code will have to be fixed.
+ *
+ * The archive also contains some Spark configuration. Namely, it saves the contents of
+ * SparkConf in a file to be loaded by the AM process.
+ */
+ private def createConfArchive(): File = {
+ val hadoopConfFiles = new HashMap[String, File]()
+
+ // Uploading $SPARK_CONF_DIR/log4j.properties file to the distributed cache to make sure that
+ // the executors will use the latest configurations instead of the default values. This is
+ // required when user changes log4j.properties directly to set the log configurations. If
+ // configuration file is provided through --files then executors will be taking configurations
+ // from --files instead of $SPARK_CONF_DIR/log4j.properties.
+
+ // Also uploading metrics.properties to distributed cache if exists in classpath.
+ // If user specify this file using --files then executors will use the one
+ // from --files instead.
+ for { prop <- Seq("log4j.properties", "metrics.properties")
+ url <- Option(Utils.getContextOrSparkClassLoader.getResource(prop))
+ if url.getProtocol == "file" } {
+ hadoopConfFiles(prop) = new File(url.getPath)
+ }
+
+ Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey =>
+ sys.env.get(envKey).foreach { path =>
+ val dir = new File(path)
+ if (dir.isDirectory()) {
+ val files = dir.listFiles()
+ if (files == null) {
+ logWarning("Failed to list files under directory " + dir)
+ } else {
+ files.foreach { file =>
+ if (file.isFile && !hadoopConfFiles.contains(file.getName())) {
+ hadoopConfFiles(file.getName()) = file
+ }
+ }
+ }
+ }
+ }
+ }
+
+ val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip",
+ new File(Utils.getLocalDir(sparkConf)))
+ val confStream = new ZipOutputStream(new FileOutputStream(confArchive))
+
+ try {
+ confStream.setLevel(0)
+ hadoopConfFiles.foreach { case (name, file) =>
+ if (file.canRead()) {
+ confStream.putNextEntry(new ZipEntry(name))
+ Files.copy(file, confStream)
+ confStream.closeEntry()
+ }
+ }
+
+ // Save Spark configuration to a file in the archive.
+ val props = new Properties()
+ sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) }
+ confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE))
+ val writer = new OutputStreamWriter(confStream, StandardCharsets.UTF_8)
+ props.store(writer, "Spark configuration.")
+ writer.flush()
+ confStream.closeEntry()
+ } finally {
+ confStream.close()
+ }
+ confArchive
+ }
+
+ /**
+ * Set up the environment for launching our ApplicationMaster container.
+ */
+ private def setupLaunchEnv(
+ stagingDirPath: Path,
+ pySparkArchives: Seq[String]): HashMap[String, String] = {
+ logInfo("Setting up the launch environment for our AM container")
+ val env = new HashMap[String, String]()
+ populateClasspath(args, yarnConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH))
+ env("SPARK_YARN_MODE") = "true"
+ env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString
+ env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
+ if (loginFromKeytab) {
+ val credentialsFile = "credentials-" + UUID.randomUUID().toString
+ sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString)
+ logInfo(s"Credentials file set to: $credentialsFile")
+ }
+
+ // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.*
+ val amEnvPrefix = "spark.yarn.appMasterEnv."
+ sparkConf.getAll
+ .filter { case (k, v) => k.startsWith(amEnvPrefix) }
+ .map { case (k, v) => (k.substring(amEnvPrefix.length), v) }
+ .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) }
+
+ // Keep this for backwards compatibility but users should move to the config
+ sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs =>
+ // Allow users to specify some environment variables.
+ YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs)
+ // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments.
+ env("SPARK_YARN_USER_ENV") = userEnvs
+ }
+
+ // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH
+ // of the container processes too. Add all non-.py files directly to PYTHONPATH.
+ //
+ // NOTE: the code currently does not handle .py files defined with a "local:" scheme.
+ val pythonPath = new ListBuffer[String]()
+ val (pyFiles, pyArchives) = sparkConf.get(PY_FILES).partition(_.endsWith(".py"))
+ if (pyFiles.nonEmpty) {
+ pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ LOCALIZED_PYTHON_DIR)
+ }
+ (pySparkArchives ++ pyArchives).foreach { path =>
+ val uri = Utils.resolveURI(path)
+ if (uri.getScheme != LOCAL_SCHEME) {
+ pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ new Path(uri).getName())
+ } else {
+ pythonPath += uri.getPath()
+ }
+ }
+
+ // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors.
+ if (pythonPath.nonEmpty) {
+ val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath)
+ .mkString(YarnSparkHadoopUtil.getClassPathSeparator)
+ env("PYTHONPATH") = pythonPathStr
+ sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr)
+ }
+
+ // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to
+ // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's
+ // SparkContext will not let that set spark* system properties, which is expected behavior for
+ // Yarn clients. So propagate it through the environment.
+ //
+ // Note that to warn the user about the deprecation in cluster mode, some code from
+ // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition
+ // described above).
+ if (isClusterMode) {
+ sys.env.get("SPARK_JAVA_OPTS").foreach { value =>
+ val warning =
+ s"""
+ |SPARK_JAVA_OPTS was detected (set to '$value').
+ |This is deprecated in Spark 1.0+.
+ |
+ |Please instead use:
+ | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application
+ | - ./spark-submit with --driver-java-options to set -X options for a driver
+ | - spark.executor.extraJavaOptions to set -X options for executors
+ """.stripMargin
+ logWarning(warning)
+ for (proc <- Seq("driver", "executor")) {
+ val key = s"spark.$proc.extraJavaOptions"
+ if (sparkConf.contains(key)) {
+ throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.")
+ }
+ }
+ env("SPARK_JAVA_OPTS") = value
+ }
+ // propagate PYSPARK_DRIVER_PYTHON and PYSPARK_PYTHON to driver in cluster mode
+ Seq("PYSPARK_DRIVER_PYTHON", "PYSPARK_PYTHON").foreach { envname =>
+ if (!env.contains(envname)) {
+ sys.env.get(envname).foreach(env(envname) = _)
+ }
+ }
+ }
+
+ sys.env.get(ENV_DIST_CLASSPATH).foreach { dcp =>
+ env(ENV_DIST_CLASSPATH) = dcp
+ }
+
+ env
+ }
+
+ /**
+ * Set up a ContainerLaunchContext to launch our ApplicationMaster container.
+ * This sets up the launch environment, java options, and the command for launching the AM.
+ */
+ private def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse)
+ : ContainerLaunchContext = {
+ logInfo("Setting up container launch context for our AM")
+ val appId = newAppResponse.getApplicationId
+ val appStagingDirPath = new Path(appStagingBaseDir, getAppStagingDir(appId))
+ val pySparkArchives =
+ if (sparkConf.get(IS_PYTHON_APP)) {
+ findPySparkArchives()
+ } else {
+ Nil
+ }
+ val launchEnv = setupLaunchEnv(appStagingDirPath, pySparkArchives)
+ val localResources = prepareLocalResources(appStagingDirPath, pySparkArchives)
+
+ val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
+ amContainer.setLocalResources(localResources.asJava)
+ amContainer.setEnvironment(launchEnv.asJava)
+
+ val javaOpts = ListBuffer[String]()
+
+ // Set the environment variable through a command prefix
+ // to append to the existing value of the variable
+ var prefixEnv: Option[String] = None
+
+ // Add Xmx for AM memory
+ javaOpts += "-Xmx" + amMemory + "m"
+
+ val tmpDir = new Path(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR
+ )
+ javaOpts += "-Djava.io.tmpdir=" + tmpDir
+
+ // TODO: Remove once cpuset version is pushed out.
+ // The context is, default gc for server class machines ends up using all cores to do gc -
+ // hence if there are multiple containers in same node, Spark GC affects all other containers'
+ // performance (which can be that of other Spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in
+ // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset
+ // of cores on a node.
+ val useConcurrentAndIncrementalGC = launchEnv.get("SPARK_USE_CONC_INCR_GC").exists(_.toBoolean)
+ if (useConcurrentAndIncrementalGC) {
+ // In our expts, using (default) throughput collector has severe perf ramifications in
+ // multi-tenant machines
+ javaOpts += "-XX:+UseConcMarkSweepGC"
+ javaOpts += "-XX:MaxTenuringThreshold=31"
+ javaOpts += "-XX:SurvivorRatio=8"
+ javaOpts += "-XX:+CMSIncrementalMode"
+ javaOpts += "-XX:+CMSIncrementalPacing"
+ javaOpts += "-XX:CMSIncrementalDutyCycleMin=0"
+ javaOpts += "-XX:CMSIncrementalDutyCycle=10"
+ }
+
+ // Include driver-specific java options if we are launching a driver
+ if (isClusterMode) {
+ val driverOpts = sparkConf.get(DRIVER_JAVA_OPTIONS).orElse(sys.env.get("SPARK_JAVA_OPTS"))
+ driverOpts.foreach { opts =>
+ javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
+ }
+ val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH),
+ sys.props.get("spark.driver.libraryPath")).flatten
+ if (libraryPaths.nonEmpty) {
+ prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths)))
+ }
+ if (sparkConf.get(AM_JAVA_OPTIONS).isDefined) {
+ logWarning(s"${AM_JAVA_OPTIONS.key} will not take effect in cluster mode")
+ }
+ } else {
+ // Validate and include yarn am specific java options in yarn-client mode.
+ sparkConf.get(AM_JAVA_OPTIONS).foreach { opts =>
+ if (opts.contains("-Dspark")) {
+ val msg = s"${AM_JAVA_OPTIONS.key} is not allowed to set Spark options (was '$opts')."
+ throw new SparkException(msg)
+ }
+ if (opts.contains("-Xmx")) {
+ val msg = s"${AM_JAVA_OPTIONS.key} is not allowed to specify max heap memory settings " +
+ s"(was '$opts'). Use spark.yarn.am.memory instead."
+ throw new SparkException(msg)
+ }
+ javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
+ }
+ sparkConf.get(AM_LIBRARY_PATH).foreach { paths =>
+ prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths))))
+ }
+ }
+
+ // For log4j configuration to reference
+ javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR)
+ YarnCommandBuilderUtils.addPermGenSizeOpt(javaOpts)
+
+ val userClass =
+ if (isClusterMode) {
+ Seq("--class", YarnSparkHadoopUtil.escapeForShell(args.userClass))
+ } else {
+ Nil
+ }
+ val userJar =
+ if (args.userJar != null) {
+ Seq("--jar", args.userJar)
+ } else {
+ Nil
+ }
+ val primaryPyFile =
+ if (isClusterMode && args.primaryPyFile != null) {
+ Seq("--primary-py-file", new Path(args.primaryPyFile).getName())
+ } else {
+ Nil
+ }
+ val primaryRFile =
+ if (args.primaryRFile != null) {
+ Seq("--primary-r-file", args.primaryRFile)
+ } else {
+ Nil
+ }
+ val amClass =
+ if (isClusterMode) {
+ Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName
+ } else {
+ Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName
+ }
+ if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
+ args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs
+ }
+ val userArgs = args.userArgs.flatMap { arg =>
+ Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg))
+ }
+ val amArgs =
+ Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++
+ userArgs ++ Seq(
+ "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ LOCALIZED_CONF_DIR, SPARK_CONF_FILE))
+
+ // Command for the ApplicationMaster
+ val commands = prefixEnv ++ Seq(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java", "-server"
+ ) ++
+ javaOpts ++ amArgs ++
+ Seq(
+ "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout",
+ "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+
+ // TODO: it would be nicer to just make sure there are no null commands here
+ val printableCommands = commands.map(s => if (s == null) "null" else s).toList
+ amContainer.setCommands(printableCommands.asJava)
+
+ logDebug("===============================================================================")
+ logDebug("YARN AM launch context:")
+ logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}")
+ logDebug(" env:")
+ launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") }
+ logDebug(" resources:")
+ localResources.foreach { case (k, v) => logDebug(s" $k -> $v")}
+ logDebug(" command:")
+ logDebug(s" ${printableCommands.mkString(" ")}")
+ logDebug("===============================================================================")
+
+ // send the acl settings into YARN to control who has access via YARN interfaces
+ val securityManager = new SecurityManager(sparkConf)
+ amContainer.setApplicationACLs(
+ YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava)
+ setupSecurityToken(amContainer)
+ amContainer
+ }
+
+ def setupCredentials(): Unit = {
+ loginFromKeytab = sparkConf.contains(PRINCIPAL.key)
+ if (loginFromKeytab) {
+ principal = sparkConf.get(PRINCIPAL).get
+ keytab = sparkConf.get(KEYTAB).orNull
+
+ require(keytab != null, "Keytab must be specified when principal is specified.")
+ logInfo("Attempting to login to the Kerberos" +
+ s" using principal: $principal and keytab: $keytab")
+ val f = new File(keytab)
+ // Generate a file name that can be used for the keytab file, that does not conflict
+ // with any user file.
+ val keytabFileName = f.getName + "-" + UUID.randomUUID().toString
+ sparkConf.set(KEYTAB.key, keytabFileName)
+ sparkConf.set(PRINCIPAL.key, principal)
+ }
+ // Defensive copy of the credentials
+ credentials = new Credentials(UserGroupInformation.getCurrentUser.getCredentials)
+ }
+
+ /**
+ * Report the state of an application until it has exited, either successfully or
+ * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED,
+ * KILLED, or RUNNING) and the final application state (UNDEFINED, SUCCEEDED, FAILED,
+ * or KILLED).
+ *
+ * @param appId ID of the application to monitor.
+ * @param returnOnRunning Whether to also return the application state when it is RUNNING.
+ * @param logApplicationReport Whether to log details of the application report every iteration.
+ * @return A pair of the yarn application state and the final application state.
+ */
+ def monitorApplication(
+ appId: ApplicationId,
+ returnOnRunning: Boolean = false,
+ logApplicationReport: Boolean = true): (YarnApplicationState, FinalApplicationStatus) = {
+ val interval = sparkConf.get(REPORT_INTERVAL)
+ var lastState: YarnApplicationState = null
+ while (true) {
+ Thread.sleep(interval)
+ val report: ApplicationReport =
+ try {
+ getApplicationReport(appId)
+ } catch {
+ case e: ApplicationNotFoundException =>
+ logError(s"Application $appId not found.")
+ cleanupStagingDir(appId)
+ return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED)
+ case NonFatal(e) =>
+ logError(s"Failed to contact YARN for application $appId.", e)
+ // Don't necessarily clean up staging dir because status is unknown
+ return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED)
+ }
+ val state = report.getYarnApplicationState
+
+ if (logApplicationReport) {
+ logInfo(s"Application report for $appId (state: $state)")
+
+ // If DEBUG is enabled, log report details every iteration
+ // Otherwise, log them every time the application changes state
+ if (log.isDebugEnabled) {
+ logDebug(formatReportDetails(report))
+ } else if (lastState != state) {
+ logInfo(formatReportDetails(report))
+ }
+ }
+
+ if (lastState != state) {
+ state match {
+ case YarnApplicationState.RUNNING =>
+ reportLauncherState(SparkAppHandle.State.RUNNING)
+ case YarnApplicationState.FINISHED =>
+ report.getFinalApplicationStatus match {
+ case FinalApplicationStatus.FAILED =>
+ reportLauncherState(SparkAppHandle.State.FAILED)
+ case FinalApplicationStatus.KILLED =>
+ reportLauncherState(SparkAppHandle.State.KILLED)
+ case _ =>
+ reportLauncherState(SparkAppHandle.State.FINISHED)
+ }
+ case YarnApplicationState.FAILED =>
+ reportLauncherState(SparkAppHandle.State.FAILED)
+ case YarnApplicationState.KILLED =>
+ reportLauncherState(SparkAppHandle.State.KILLED)
+ case _ =>
+ }
+ }
+
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ cleanupStagingDir(appId)
+ return (state, report.getFinalApplicationStatus)
+ }
+
+ if (returnOnRunning && state == YarnApplicationState.RUNNING) {
+ return (state, report.getFinalApplicationStatus)
+ }
+
+ lastState = state
+ }
+
+ // Never reached, but keeps compiler happy
+ throw new SparkException("While loop is depleted! This should never happen...")
+ }
+
+ private def formatReportDetails(report: ApplicationReport): String = {
+ val details = Seq[(String, String)](
+ ("client token", getClientToken(report)),
+ ("diagnostics", report.getDiagnostics),
+ ("ApplicationMaster host", report.getHost),
+ ("ApplicationMaster RPC port", report.getRpcPort.toString),
+ ("queue", report.getQueue),
+ ("start time", report.getStartTime.toString),
+ ("final status", report.getFinalApplicationStatus.toString),
+ ("tracking URL", report.getTrackingUrl),
+ ("user", report.getUser)
+ )
+
+ // Use more loggable format if value is null or empty
+ details.map { case (k, v) =>
+ val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A")
+ s"\n\t $k: $newValue"
+ }.mkString("")
+ }
+
+ /**
+ * Submit an application to the ResourceManager.
+ * If set spark.yarn.submit.waitAppCompletion to true, it will stay alive
+ * reporting the application's status until the application has exited for any reason.
+ * Otherwise, the client process will exit after submission.
+ * If the application finishes with a failed, killed, or undefined status,
+ * throw an appropriate SparkException.
+ */
+ def run(): Unit = {
+ this.appId = submitApplication()
+ if (!launcherBackend.isConnected() && fireAndForget) {
+ val report = getApplicationReport(appId)
+ val state = report.getYarnApplicationState
+ logInfo(s"Application report for $appId (state: $state)")
+ logInfo(formatReportDetails(report))
+ if (state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) {
+ throw new SparkException(s"Application $appId finished with status: $state")
+ }
+ } else {
+ val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId)
+ if (yarnApplicationState == YarnApplicationState.FAILED ||
+ finalApplicationStatus == FinalApplicationStatus.FAILED) {
+ throw new SparkException(s"Application $appId finished with failed status")
+ }
+ if (yarnApplicationState == YarnApplicationState.KILLED ||
+ finalApplicationStatus == FinalApplicationStatus.KILLED) {
+ throw new SparkException(s"Application $appId is killed")
+ }
+ if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) {
+ throw new SparkException(s"The final status of application $appId is undefined")
+ }
+ }
+ }
+
+ private def findPySparkArchives(): Seq[String] = {
+ sys.env.get("PYSPARK_ARCHIVES_PATH")
+ .map(_.split(",").toSeq)
+ .getOrElse {
+ val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator)
+ val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
+ require(pyArchivesFile.exists(),
+ s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.")
+ val py4jFile = new File(pyLibPath, "py4j-0.10.4-src.zip")
+ require(py4jFile.exists(),
+ s"$py4jFile not found; cannot run pyspark application in YARN mode.")
+ Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath())
+ }
+ }
+
+}
+
+private object Client extends Logging {
+
+ def main(argStrings: Array[String]) {
+ if (!sys.props.contains("SPARK_SUBMIT")) {
+ logWarning("WARNING: This client is deprecated and will be removed in a " +
+ "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"")
+ }
+
+ // Set an env variable indicating we are running in YARN mode.
+ // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes
+ System.setProperty("SPARK_YARN_MODE", "true")
+ val sparkConf = new SparkConf
+ // SparkSubmit would use yarn cache to distribute files & jars in yarn mode,
+ // so remove them from sparkConf here for yarn mode.
+ sparkConf.remove("spark.jars")
+ sparkConf.remove("spark.files")
+ val args = new ClientArguments(argStrings)
+ new Client(args, sparkConf).run()
+ }
+
+ // Alias for the user jar
+ val APP_JAR_NAME: String = "__app__.jar"
+
+ // URI scheme that identifies local resources
+ val LOCAL_SCHEME = "local"
+
+ // Staging directory for any temporary jars or files
+ val SPARK_STAGING: String = ".sparkStaging"
+
+
+ // Staging directory is private! -> rwx--------
+ val STAGING_DIR_PERMISSION: FsPermission =
+ FsPermission.createImmutable(Integer.parseInt("700", 8).toShort)
+
+ // App files are world-wide readable and owner writable -> rw-r--r--
+ val APP_FILE_PERMISSION: FsPermission =
+ FsPermission.createImmutable(Integer.parseInt("644", 8).toShort)
+
+ // Distribution-defined classpath to add to processes
+ val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH"
+
+ // Subdirectory where the user's Spark and Hadoop config files will be placed.
+ val LOCALIZED_CONF_DIR = "__spark_conf__"
+
+ // File containing the conf archive in the AM. See prepareLocalResources().
+ val LOCALIZED_CONF_ARCHIVE = LOCALIZED_CONF_DIR + ".zip"
+
+ // Name of the file in the conf archive containing Spark configuration.
+ val SPARK_CONF_FILE = "__spark_conf__.properties"
+
+ // Subdirectory where the user's python files (not archives) will be placed.
+ val LOCALIZED_PYTHON_DIR = "__pyfiles__"
+
+ // Subdirectory where Spark libraries will be placed.
+ val LOCALIZED_LIB_DIR = "__spark_libs__"
+
+ /**
+ * Return the path to the given application's staging directory.
+ */
+ private def getAppStagingDir(appId: ApplicationId): String = {
+ buildPath(SPARK_STAGING, appId.toString())
+ }
+
+ /**
+ * Populate the classpath entry in the given environment map with any application
+ * classpath specified through the Hadoop and Yarn configurations.
+ */
+ private[yarn] def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String])
+ : Unit = {
+ val classPathElementsToAdd = getYarnAppClasspath(conf) ++ getMRAppClasspath(conf)
+ for (c <- classPathElementsToAdd.flatten) {
+ YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, c.trim)
+ }
+ }
+
+ private def getYarnAppClasspath(conf: Configuration): Option[Seq[String]] =
+ Option(conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) match {
+ case Some(s) => Some(s.toSeq)
+ case None => getDefaultYarnApplicationClasspath
+ }
+
+ private def getMRAppClasspath(conf: Configuration): Option[Seq[String]] =
+ Option(conf.getStrings("mapreduce.application.classpath")) match {
+ case Some(s) => Some(s.toSeq)
+ case None => getDefaultMRApplicationClasspath
+ }
+
+ private[yarn] def getDefaultYarnApplicationClasspath: Option[Seq[String]] = {
+ val triedDefault = Try[Seq[String]] {
+ val field = classOf[YarnConfiguration].getField("DEFAULT_YARN_APPLICATION_CLASSPATH")
+ val value = field.get(null).asInstanceOf[Array[String]]
+ value.toSeq
+ } recoverWith {
+ case e: NoSuchFieldException => Success(Seq.empty[String])
+ }
+
+ triedDefault match {
+ case f: Failure[_] =>
+ logError("Unable to obtain the default YARN Application classpath.", f.exception)
+ case s: Success[Seq[String]] =>
+ logDebug(s"Using the default YARN application classpath: ${s.get.mkString(",")}")
+ }
+
+ triedDefault.toOption
+ }
+
+ private[yarn] def getDefaultMRApplicationClasspath: Option[Seq[String]] = {
+ val triedDefault = Try[Seq[String]] {
+ val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH")
+ StringUtils.getStrings(field.get(null).asInstanceOf[String]).toSeq
+ } recoverWith {
+ case e: NoSuchFieldException => Success(Seq.empty[String])
+ }
+
+ triedDefault match {
+ case f: Failure[_] =>
+ logError("Unable to obtain the default MR Application classpath.", f.exception)
+ case s: Success[Seq[String]] =>
+ logDebug(s"Using the default MR application classpath: ${s.get.mkString(",")}")
+ }
+
+ triedDefault.toOption
+ }
+
+ /**
+ * Populate the classpath entry in the given environment map.
+ *
+ * User jars are generally not added to the JVM's system classpath; those are handled by the AM
+ * and executor backend. When the deprecated `spark.yarn.user.classpath.first` is used, user jars
+ * are included in the system classpath, though. The extra class path and other uploaded files are
+ * always made available through the system class path.
+ *
+ * @param args Client arguments (when starting the AM) or null (when starting executors).
+ */
+ private[yarn] def populateClasspath(
+ args: ClientArguments,
+ conf: Configuration,
+ sparkConf: SparkConf,
+ env: HashMap[String, String],
+ extraClassPath: Option[String] = None): Unit = {
+ extraClassPath.foreach { cp =>
+ addClasspathEntry(getClusterPath(sparkConf, cp), env)
+ }
+
+ addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env)
+
+ addClasspathEntry(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR +
+ LOCALIZED_CONF_DIR, env)
+
+ if (sparkConf.get(USER_CLASS_PATH_FIRST)) {
+ // in order to properly add the app jar when user classpath is first
+ // we have to do the mainJar separate in order to send the right thing
+ // into addFileToClasspath
+ val mainJar =
+ if (args != null) {
+ getMainJarUri(Option(args.userJar))
+ } else {
+ getMainJarUri(sparkConf.get(APP_JAR))
+ }
+ mainJar.foreach(addFileToClasspath(sparkConf, conf, _, APP_JAR_NAME, env))
+
+ val secondaryJars =
+ if (args != null) {
+ getSecondaryJarUris(Option(sparkConf.get(JARS_TO_DISTRIBUTE)))
+ } else {
+ getSecondaryJarUris(sparkConf.get(SECONDARY_JARS))
+ }
+ secondaryJars.foreach { x =>
+ addFileToClasspath(sparkConf, conf, x, null, env)
+ }
+ }
+
+ // Add the Spark jars to the classpath, depending on how they were distributed.
+ addClasspathEntry(buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ LOCALIZED_LIB_DIR, "*"), env)
+ if (!sparkConf.get(SPARK_ARCHIVE).isDefined) {
+ sparkConf.get(SPARK_JARS).foreach { jars =>
+ jars.filter(isLocalUri).foreach { jar =>
+ addClasspathEntry(getClusterPath(sparkConf, jar), env)
+ }
+ }
+ }
+
+ populateHadoopClasspath(conf, env)
+ sys.env.get(ENV_DIST_CLASSPATH).foreach { cp =>
+ addClasspathEntry(getClusterPath(sparkConf, cp), env)
+ }
+ }
+
+ /**
+ * Returns a list of URIs representing the user classpath.
+ *
+ * @param conf Spark configuration.
+ */
+ def getUserClasspath(conf: SparkConf): Array[URI] = {
+ val mainUri = getMainJarUri(conf.get(APP_JAR))
+ val secondaryUris = getSecondaryJarUris(conf.get(SECONDARY_JARS))
+ (mainUri ++ secondaryUris).toArray
+ }
+
+ private def getMainJarUri(mainJar: Option[String]): Option[URI] = {
+ mainJar.flatMap { path =>
+ val uri = Utils.resolveURI(path)
+ if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None
+ }.orElse(Some(new URI(APP_JAR_NAME)))
+ }
+
+ private def getSecondaryJarUris(secondaryJars: Option[Seq[String]]): Seq[URI] = {
+ secondaryJars.getOrElse(Nil).map(new URI(_))
+ }
+
+ /**
+ * Adds the given path to the classpath, handling "local:" URIs correctly.
+ *
+ * If an alternate name for the file is given, and it's not a "local:" file, the alternate
+ * name will be added to the classpath (relative to the job's work directory).
+ *
+ * If not a "local:" file and no alternate name, the linkName will be added to the classpath.
+ *
+ * @param conf Spark configuration.
+ * @param hadoopConf Hadoop configuration.
+ * @param uri URI to add to classpath (optional).
+ * @param fileName Alternate name for the file (optional).
+ * @param env Map holding the environment variables.
+ */
+ private def addFileToClasspath(
+ conf: SparkConf,
+ hadoopConf: Configuration,
+ uri: URI,
+ fileName: String,
+ env: HashMap[String, String]): Unit = {
+ if (uri != null && uri.getScheme == LOCAL_SCHEME) {
+ addClasspathEntry(getClusterPath(conf, uri.getPath), env)
+ } else if (fileName != null) {
+ addClasspathEntry(buildPath(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env)
+ } else if (uri != null) {
+ val localPath = getQualifiedLocalPath(uri, hadoopConf)
+ val linkName = Option(uri.getFragment()).getOrElse(localPath.getName())
+ addClasspathEntry(buildPath(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), linkName), env)
+ }
+ }
+
+ /**
+ * Add the given path to the classpath entry of the given environment map.
+ * If the classpath is already set, this appends the new path to the existing classpath.
+ */
+ private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit =
+ YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path)
+
+ /**
+ * Returns the path to be sent to the NM for a path that is valid on the gateway.
+ *
+ * This method uses two configuration values:
+ *
+ * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may
+ * only be valid in the gateway node.
+ * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may
+ * contain, for example, env variable references, which will be expanded by the NMs when
+ * starting containers.
+ *
+ * If either config is not available, the input path is returned.
+ */
+ def getClusterPath(conf: SparkConf, path: String): String = {
+ val localPath = conf.get(GATEWAY_ROOT_PATH)
+ val clusterPath = conf.get(REPLACEMENT_ROOT_PATH)
+ if (localPath != null && clusterPath != null) {
+ path.replace(localPath, clusterPath)
+ } else {
+ path
+ }
+ }
+
+ /**
+ * Return whether the two file systems are the same.
+ */
+ private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = {
+ val srcUri = srcFs.getUri()
+ val dstUri = destFs.getUri()
+ if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) {
+ return false
+ }
+
+ var srcHost = srcUri.getHost()
+ var dstHost = dstUri.getHost()
+
+ // In HA or when using viewfs, the host part of the URI may not actually be a host, but the
+ // name of the HDFS namespace. Those names won't resolve, so avoid even trying if they
+ // match.
+ if (srcHost != null && dstHost != null && srcHost != dstHost) {
+ try {
+ srcHost = InetAddress.getByName(srcHost).getCanonicalHostName()
+ dstHost = InetAddress.getByName(dstHost).getCanonicalHostName()
+ } catch {
+ case e: UnknownHostException =>
+ return false
+ }
+ }
+
+ Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort()
+ }
+
+ /**
+ * Given a local URI, resolve it and return a qualified local path that corresponds to the URI.
+ * This is used for preparing local resources to be included in the container launch context.
+ */
+ private def getQualifiedLocalPath(localURI: URI, hadoopConf: Configuration): Path = {
+ val qualifiedURI =
+ if (localURI.getScheme == null) {
+ // If not specified, assume this is in the local filesystem to keep the behavior
+ // consistent with that of Hadoop
+ new URI(FileSystem.getLocal(hadoopConf).makeQualified(new Path(localURI)).toString)
+ } else {
+ localURI
+ }
+ new Path(qualifiedURI)
+ }
+
+ /**
+ * Whether to consider jars provided by the user to have precedence over the Spark jars when
+ * loading user classes.
+ */
+ def isUserClassPathFirst(conf: SparkConf, isDriver: Boolean): Boolean = {
+ if (isDriver) {
+ conf.get(DRIVER_USER_CLASS_PATH_FIRST)
+ } else {
+ conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
+ }
+ }
+
+ /**
+ * Joins all the path components using Path.SEPARATOR.
+ */
+ def buildPath(components: String*): String = {
+ components.mkString(Path.SEPARATOR)
+ }
+
+ /** Returns whether the URI is a "local:" URI. */
+ def isLocalUri(uri: String): Boolean = {
+ uri.startsWith(s"$LOCAL_SCHEME:")
+ }
+
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
new file mode 100644
index 0000000000..61c027ec44
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.mutable.ArrayBuffer
+
+// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
+private[spark] class ClientArguments(args: Array[String]) {
+
+ var userJar: String = null
+ var userClass: String = null
+ var primaryPyFile: String = null
+ var primaryRFile: String = null
+ var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ var args = inputArgs
+
+ while (!args.isEmpty) {
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--primary-py-file") :: value :: tail =>
+ primaryPyFile = value
+ args = tail
+
+ case ("--primary-r-file") :: value :: tail =>
+ primaryRFile = value
+ args = tail
+
+ case ("--arg") :: value :: tail =>
+ userArgs += value
+ args = tail
+
+ case Nil =>
+
+ case _ =>
+ throw new IllegalArgumentException(getUsageMessage(args))
+ }
+ }
+
+ if (primaryPyFile != null && primaryRFile != null) {
+ throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" +
+ " at the same time")
+ }
+ }
+
+ private def getUsageMessage(unknownParam: List[String] = null): String = {
+ val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else ""
+ message +
+ s"""
+ |Usage: org.apache.spark.deploy.yarn.Client [options]
+ |Options:
+ | --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster
+ | mode)
+ | --class CLASS_NAME Name of your application's main class (required)
+ | --primary-py-file A main Python file
+ | --primary-r-file A main R file
+ | --arg ARG Argument to be passed to your application's main class.
+ | Multiple invocations are possible, each will be passed in order.
+ """.stripMargin
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
new file mode 100644
index 0000000000..dcc2288dd1
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import scala.collection.mutable.{HashMap, ListBuffer, Map}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+
+private case class CacheEntry(
+ uri: URI,
+ size: Long,
+ modTime: Long,
+ visibility: LocalResourceVisibility,
+ resType: LocalResourceType)
+
+/** Client side methods to setup the Hadoop distributed cache */
+private[spark] class ClientDistributedCacheManager() extends Logging {
+
+ private val distCacheEntries = new ListBuffer[CacheEntry]()
+
+ /**
+ * Add a resource to the list of distributed cache resources. This list can
+ * be sent to the ApplicationMaster and possibly the executors so that it can
+ * be downloaded into the Hadoop distributed cache for use by this application.
+ * Adds the LocalResource to the localResources HashMap passed in and saves
+ * the stats of the resources to they can be sent to the executors and verified.
+ *
+ * @param fs FileSystem
+ * @param conf Configuration
+ * @param destPath path to the resource
+ * @param localResources localResource hashMap to insert the resource into
+ * @param resourceType LocalResourceType
+ * @param link link presented in the distributed cache to the destination
+ * @param statCache cache to store the file/directory stats
+ * @param appMasterOnly Whether to only add the resource to the app master
+ */
+ def addResource(
+ fs: FileSystem,
+ conf: Configuration,
+ destPath: Path,
+ localResources: HashMap[String, LocalResource],
+ resourceType: LocalResourceType,
+ link: String,
+ statCache: Map[URI, FileStatus],
+ appMasterOnly: Boolean = false): Unit = {
+ val destStatus = fs.getFileStatus(destPath)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource])
+ amJarRsrc.setType(resourceType)
+ val visibility = getVisibility(conf, destPath.toUri(), statCache)
+ amJarRsrc.setVisibility(visibility)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ require(link != null && link.nonEmpty, "You must specify a valid link name.")
+ localResources(link) = amJarRsrc
+
+ if (!appMasterOnly) {
+ val uri = destPath.toUri()
+ val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link)
+ distCacheEntries += CacheEntry(pathURI, destStatus.getLen(), destStatus.getModificationTime(),
+ visibility, resourceType)
+ }
+ }
+
+ /**
+ * Writes down information about cached files needed in executors to the given configuration.
+ */
+ def updateConfiguration(conf: SparkConf): Unit = {
+ conf.set(CACHED_FILES, distCacheEntries.map(_.uri.toString))
+ conf.set(CACHED_FILES_SIZES, distCacheEntries.map(_.size))
+ conf.set(CACHED_FILES_TIMESTAMPS, distCacheEntries.map(_.modTime))
+ conf.set(CACHED_FILES_VISIBILITIES, distCacheEntries.map(_.visibility.name()))
+ conf.set(CACHED_FILES_TYPES, distCacheEntries.map(_.resType.name()))
+ }
+
+ /**
+ * Returns the local resource visibility depending on the cache file permissions
+ * @return LocalResourceVisibility
+ */
+ private[yarn] def getVisibility(
+ conf: Configuration,
+ uri: URI,
+ statCache: Map[URI, FileStatus]): LocalResourceVisibility = {
+ if (isPublic(conf, uri, statCache)) {
+ LocalResourceVisibility.PUBLIC
+ } else {
+ LocalResourceVisibility.PRIVATE
+ }
+ }
+
+ /**
+ * Returns a boolean to denote whether a cache file is visible to all (public)
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ private def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = {
+ val fs = FileSystem.get(uri, conf)
+ val current = new Path(uri.getPath())
+ // the leaf level file should be readable by others
+ if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) {
+ return false
+ }
+ ancestorsHaveExecutePermissions(fs, current.getParent(), statCache)
+ }
+
+ /**
+ * Returns true if all ancestors of the specified path have the 'execute'
+ * permission set for all users (i.e. that other users can traverse
+ * the directory hierarchy to the given path)
+ * @return true if all ancestors have the 'execute' permission set for all users
+ */
+ private def ancestorsHaveExecutePermissions(
+ fs: FileSystem,
+ path: Path,
+ statCache: Map[URI, FileStatus]): Boolean = {
+ var current = path
+ while (current != null) {
+ // the subdirs in the path should have execute permissions for others
+ if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) {
+ return false
+ }
+ current = current.getParent()
+ }
+ true
+ }
+
+ /**
+ * Checks for a given path whether the Other permissions on it
+ * imply the permission in the passed FsAction
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ private def checkPermissionOfOther(
+ fs: FileSystem,
+ path: Path,
+ action: FsAction,
+ statCache: Map[URI, FileStatus]): Boolean = {
+ val status = getFileStatus(fs, path.toUri(), statCache)
+ val perms = status.getPermission()
+ val otherAction = perms.getOtherAction()
+ otherAction.implies(action)
+ }
+
+ /**
+ * Checks to see if the given uri exists in the cache, if it does it
+ * returns the existing FileStatus, otherwise it stats the uri, stores
+ * it in the cache, and returns the FileStatus.
+ * @return FileStatus
+ */
+ private[yarn] def getFileStatus(
+ fs: FileSystem,
+ uri: URI,
+ statCache: Map[URI, FileStatus]): FileStatus = {
+ val stat = statCache.get(uri) match {
+ case Some(existstat) => existstat
+ case None =>
+ val newStat = fs.getFileStatus(new Path(uri))
+ statCache.put(uri, newStat)
+ newStat
+ }
+ stat
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
new file mode 100644
index 0000000000..868c2edc5a
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.nio.ByteBuffer
+import java.util.Collections
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{HashMap, ListBuffer}
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.NMClient
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkException}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.YarnCommandBuilderUtils
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
+
+private[yarn] class ExecutorRunnable(
+ container: Option[Container],
+ conf: YarnConfiguration,
+ sparkConf: SparkConf,
+ masterAddress: String,
+ executorId: String,
+ hostname: String,
+ executorMemory: Int,
+ executorCores: Int,
+ appId: String,
+ securityMgr: SecurityManager,
+ localResources: Map[String, LocalResource]) extends Logging {
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ var nmClient: NMClient = _
+
+ def run(): Unit = {
+ logDebug("Starting Executor Container")
+ nmClient = NMClient.createNMClient()
+ nmClient.init(conf)
+ nmClient.start()
+ startContainer()
+ }
+
+ def launchContextDebugInfo(): String = {
+ val commands = prepareCommand()
+ val env = prepareEnvironment()
+
+ s"""
+ |===============================================================================
+ |YARN executor launch context:
+ | env:
+ |${Utils.redact(sparkConf, env.toSeq).map { case (k, v) => s" $k -> $v\n" }.mkString}
+ | command:
+ | ${commands.mkString(" \\ \n ")}
+ |
+ | resources:
+ |${localResources.map { case (k, v) => s" $k -> $v\n" }.mkString}
+ |===============================================================================""".stripMargin
+ }
+
+ def startContainer(): java.util.Map[String, ByteBuffer] = {
+ val ctx = Records.newRecord(classOf[ContainerLaunchContext])
+ .asInstanceOf[ContainerLaunchContext]
+ val env = prepareEnvironment().asJava
+
+ ctx.setLocalResources(localResources.asJava)
+ ctx.setEnvironment(env)
+
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+ val dob = new DataOutputBuffer()
+ credentials.writeTokenStorageToStream(dob)
+ ctx.setTokens(ByteBuffer.wrap(dob.getData()))
+
+ val commands = prepareCommand()
+
+ ctx.setCommands(commands.asJava)
+ ctx.setApplicationACLs(
+ YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr).asJava)
+
+ // If external shuffle service is enabled, register with the Yarn shuffle service already
+ // started on the NodeManager and, if authentication is enabled, provide it with our secret
+ // key for fetching shuffle files later
+ if (sparkConf.get(SHUFFLE_SERVICE_ENABLED)) {
+ val secretString = securityMgr.getSecretKey()
+ val secretBytes =
+ if (secretString != null) {
+ // This conversion must match how the YarnShuffleService decodes our secret
+ JavaUtils.stringToBytes(secretString)
+ } else {
+ // Authentication is not enabled, so just provide dummy metadata
+ ByteBuffer.allocate(0)
+ }
+ ctx.setServiceData(Collections.singletonMap("spark_shuffle", secretBytes))
+ }
+
+ // Send the start request to the ContainerManager
+ try {
+ nmClient.startContainer(container.get, ctx)
+ } catch {
+ case ex: Exception =>
+ throw new SparkException(s"Exception while starting container ${container.get.getId}" +
+ s" on host $hostname", ex)
+ }
+ }
+
+ private def prepareCommand(): List[String] = {
+ // Extra options for the JVM
+ val javaOpts = ListBuffer[String]()
+
+ // Set the environment variable through a command prefix
+ // to append to the existing value of the variable
+ var prefixEnv: Option[String] = None
+
+ // Set the JVM memory
+ val executorMemoryString = executorMemory + "m"
+ javaOpts += "-Xmx" + executorMemoryString
+
+ // Set extra Java options for the executor, if defined
+ sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts =>
+ javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
+ }
+ sys.env.get("SPARK_JAVA_OPTS").foreach { opts =>
+ javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
+ }
+ sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p =>
+ prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p))))
+ }
+
+ javaOpts += "-Djava.io.tmpdir=" +
+ new Path(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR
+ )
+
+ // Certain configs need to be passed here because they are needed before the Executor
+ // registers with the Scheduler and transfers the spark configs. Since the Executor backend
+ // uses RPC to connect to the scheduler, the RPC settings are needed as well as the
+ // authentication settings.
+ sparkConf.getAll
+ .filter { case (k, v) => SparkConf.isExecutorStartupConf(k) }
+ .foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") }
+
+ // Commenting it out for now - so that people can refer to the properties if required. Remove
+ // it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence
+ // if there are multiple containers in same node, spark gc effects all other containers
+ // performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in
+ // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset
+ // of cores on a node.
+ /*
+ else {
+ // If no java_opts specified, default to using -XX:+CMSIncrementalMode
+ // It might be possible that other modes/config is being done in
+ // spark.executor.extraJavaOptions, so we don't want to mess with it.
+ // In our expts, using (default) throughput collector has severe perf ramifications in
+ // multi-tenant machines
+ // The options are based on
+ // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use
+ // %20the%20Concurrent%20Low%20Pause%20Collector|outline
+ javaOpts += "-XX:+UseConcMarkSweepGC"
+ javaOpts += "-XX:+CMSIncrementalMode"
+ javaOpts += "-XX:+CMSIncrementalPacing"
+ javaOpts += "-XX:CMSIncrementalDutyCycleMin=0"
+ javaOpts += "-XX:CMSIncrementalDutyCycle=10"
+ }
+ */
+
+ // For log4j configuration to reference
+ javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR)
+ YarnCommandBuilderUtils.addPermGenSizeOpt(javaOpts)
+
+ val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri =>
+ val absPath =
+ if (new File(uri.getPath()).isAbsolute()) {
+ Client.getClusterPath(sparkConf, uri.getPath())
+ } else {
+ Client.buildPath(Environment.PWD.$(), uri.getPath())
+ }
+ Seq("--user-class-path", "file:" + absPath)
+ }.toSeq
+
+ YarnSparkHadoopUtil.addOutOfMemoryErrorArgument(javaOpts)
+ val commands = prefixEnv ++ Seq(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java",
+ "-server") ++
+ javaOpts ++
+ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend",
+ "--driver-url", masterAddress,
+ "--executor-id", executorId,
+ "--hostname", hostname,
+ "--cores", executorCores.toString,
+ "--app-id", appId) ++
+ userClassPath ++
+ Seq(
+ s"1>${ApplicationConstants.LOG_DIR_EXPANSION_VAR}/stdout",
+ s"2>${ApplicationConstants.LOG_DIR_EXPANSION_VAR}/stderr")
+
+ // TODO: it would be nicer to just make sure there are no null commands here
+ commands.map(s => if (s == null) "null" else s).toList
+ }
+
+ private def prepareEnvironment(): HashMap[String, String] = {
+ val env = new HashMap[String, String]()
+ Client.populateClasspath(null, conf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH))
+
+ sparkConf.getExecutorEnv.foreach { case (key, value) =>
+ // This assumes each executor environment variable set here is a path
+ // This is kept for backward compatibility and consistency with hadoop
+ YarnSparkHadoopUtil.addPathToEnvironment(env, key, value)
+ }
+
+ // Keep this for backwards compatibility but users should move to the config
+ sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs =>
+ YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs)
+ }
+
+ // lookup appropriate http scheme for container log urls
+ val yarnHttpPolicy = conf.get(
+ YarnConfiguration.YARN_HTTP_POLICY_KEY,
+ YarnConfiguration.YARN_HTTP_POLICY_DEFAULT
+ )
+ val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://"
+
+ // Add log urls
+ container.foreach { c =>
+ sys.env.get("SPARK_USER").foreach { user =>
+ val containerId = ConverterUtils.toString(c.getId)
+ val address = c.getNodeHttpAddress
+ val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user"
+
+ env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=-4096"
+ env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096"
+ }
+ }
+
+ System.getenv().asScala.filterKeys(_.startsWith("SPARK"))
+ .foreach { case (k, v) => env(k) = v }
+ env
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
new file mode 100644
index 0000000000..8772e26f43
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, Set}
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.records.{ContainerId, Resource}
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.util.RackResolver
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config._
+
+private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], racks: Array[String])
+
+/**
+ * This strategy is calculating the optimal locality preferences of YARN containers by considering
+ * the node ratio of pending tasks, number of required cores/containers and and locality of current
+ * existing and pending allocated containers. The target of this algorithm is to maximize the number
+ * of tasks that would run locally.
+ *
+ * Consider a situation in which we have 20 tasks that require (host1, host2, host3)
+ * and 10 tasks that require (host1, host2, host4), besides each container has 2 cores
+ * and cpus per task is 1, so the required container number is 15,
+ * and host ratio is (host1: 30, host2: 30, host3: 20, host4: 10).
+ *
+ * 1. If requested container number (18) is more than the required container number (15):
+ *
+ * requests for 5 containers with nodes: (host1, host2, host3, host4)
+ * requests for 5 containers with nodes: (host1, host2, host3)
+ * requests for 5 containers with nodes: (host1, host2)
+ * requests for 3 containers with no locality preferences.
+ *
+ * The placement ratio is 3 : 3 : 2 : 1, and set the additional containers with no locality
+ * preferences.
+ *
+ * 2. If requested container number (10) is less than or equal to the required container number
+ * (15):
+ *
+ * requests for 4 containers with nodes: (host1, host2, host3, host4)
+ * requests for 3 containers with nodes: (host1, host2, host3)
+ * requests for 3 containers with nodes: (host1, host2)
+ *
+ * The placement ratio is 10 : 10 : 7 : 4, close to expected ratio (3 : 3 : 2 : 1)
+ *
+ * 3. If containers exist but none of them can match the requested localities,
+ * follow the method of 1 and 2.
+ *
+ * 4. If containers exist and some of them can match the requested localities.
+ * For example if we have 1 containers on each node (host1: 1, host2: 1: host3: 1, host4: 1),
+ * and the expected containers on each node would be (host1: 5, host2: 5, host3: 4, host4: 2),
+ * so the newly requested containers on each node would be updated to (host1: 4, host2: 4,
+ * host3: 3, host4: 1), 12 containers by total.
+ *
+ * 4.1 If requested container number (18) is more than newly required containers (12). Follow
+ * method 1 with updated ratio 4 : 4 : 3 : 1.
+ *
+ * 4.2 If request container number (10) is more than newly required containers (12). Follow
+ * method 2 with updated ratio 4 : 4 : 3 : 1.
+ *
+ * 5. If containers exist and existing localities can fully cover the requested localities.
+ * For example if we have 5 containers on each node (host1: 5, host2: 5, host3: 5, host4: 5),
+ * which could cover the current requested localities. This algorithm will allocate all the
+ * requested containers with no localities.
+ */
+private[yarn] class LocalityPreferredContainerPlacementStrategy(
+ val sparkConf: SparkConf,
+ val yarnConf: Configuration,
+ val resource: Resource) {
+
+ /**
+ * Calculate each container's node locality and rack locality
+ * @param numContainer number of containers to calculate
+ * @param numLocalityAwareTasks number of locality required tasks
+ * @param hostToLocalTaskCount a map to store the preferred hostname and possible task
+ * numbers running on it, used as hints for container allocation
+ * @param allocatedHostToContainersMap host to allocated containers map, used to calculate the
+ * expected locality preference by considering the existing
+ * containers
+ * @param localityMatchedPendingAllocations A sequence of pending container request which
+ * matches the localities of current required tasks.
+ * @return node localities and rack localities, each locality is an array of string,
+ * the length of localities is the same as number of containers
+ */
+ def localityOfRequestedContainers(
+ numContainer: Int,
+ numLocalityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int],
+ allocatedHostToContainersMap: HashMap[String, Set[ContainerId]],
+ localityMatchedPendingAllocations: Seq[ContainerRequest]
+ ): Array[ContainerLocalityPreferences] = {
+ val updatedHostToContainerCount = expectedHostToContainerCount(
+ numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap,
+ localityMatchedPendingAllocations)
+ val updatedLocalityAwareContainerNum = updatedHostToContainerCount.values.sum
+
+ // The number of containers to allocate, divided into two groups, one with preferred locality,
+ // and the other without locality preference.
+ val requiredLocalityFreeContainerNum =
+ math.max(0, numContainer - updatedLocalityAwareContainerNum)
+ val requiredLocalityAwareContainerNum = numContainer - requiredLocalityFreeContainerNum
+
+ val containerLocalityPreferences = ArrayBuffer[ContainerLocalityPreferences]()
+ if (requiredLocalityFreeContainerNum > 0) {
+ for (i <- 0 until requiredLocalityFreeContainerNum) {
+ containerLocalityPreferences += ContainerLocalityPreferences(
+ null.asInstanceOf[Array[String]], null.asInstanceOf[Array[String]])
+ }
+ }
+
+ if (requiredLocalityAwareContainerNum > 0) {
+ val largestRatio = updatedHostToContainerCount.values.max
+ // Round the ratio of preferred locality to the number of locality required container
+ // number, which is used for locality preferred host calculating.
+ var preferredLocalityRatio = updatedHostToContainerCount.mapValues { ratio =>
+ val adjustedRatio = ratio.toDouble * requiredLocalityAwareContainerNum / largestRatio
+ adjustedRatio.ceil.toInt
+ }
+
+ for (i <- 0 until requiredLocalityAwareContainerNum) {
+ // Only filter out the ratio which is larger than 0, which means the current host can
+ // still be allocated with new container request.
+ val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray
+ val racks = hosts.map { h =>
+ RackResolver.resolve(yarnConf, h).getNetworkLocation
+ }.toSet
+ containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray)
+
+ // Minus 1 each time when the host is used. When the current ratio is 0,
+ // which means all the required ratio is satisfied, this host will not be allocated again.
+ preferredLocalityRatio = preferredLocalityRatio.mapValues(_ - 1)
+ }
+ }
+
+ containerLocalityPreferences.toArray
+ }
+
+ /**
+ * Calculate the number of executors need to satisfy the given number of pending tasks.
+ */
+ private def numExecutorsPending(numTasksPending: Int): Int = {
+ val coresPerExecutor = resource.getVirtualCores
+ (numTasksPending * sparkConf.get(CPUS_PER_TASK) + coresPerExecutor - 1) / coresPerExecutor
+ }
+
+ /**
+ * Calculate the expected host to number of containers by considering with allocated containers.
+ * @param localityAwareTasks number of locality aware tasks
+ * @param hostToLocalTaskCount a map to store the preferred hostname and possible task
+ * numbers running on it, used as hints for container allocation
+ * @param allocatedHostToContainersMap host to allocated containers map, used to calculate the
+ * expected locality preference by considering the existing
+ * containers
+ * @param localityMatchedPendingAllocations A sequence of pending container request which
+ * matches the localities of current required tasks.
+ * @return a map with hostname as key and required number of containers on this host as value
+ */
+ private def expectedHostToContainerCount(
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int],
+ allocatedHostToContainersMap: HashMap[String, Set[ContainerId]],
+ localityMatchedPendingAllocations: Seq[ContainerRequest]
+ ): Map[String, Int] = {
+ val totalLocalTaskNum = hostToLocalTaskCount.values.sum
+ val pendingHostToContainersMap = pendingHostToContainerCount(localityMatchedPendingAllocations)
+
+ hostToLocalTaskCount.map { case (host, count) =>
+ val expectedCount =
+ count.toDouble * numExecutorsPending(localityAwareTasks) / totalLocalTaskNum
+ // Take the locality of pending containers into consideration
+ val existedCount = allocatedHostToContainersMap.get(host).map(_.size).getOrElse(0) +
+ pendingHostToContainersMap.getOrElse(host, 0.0)
+
+ // If existing container can not fully satisfy the expected number of container,
+ // the required container number is expected count minus existed count. Otherwise the
+ // required container number is 0.
+ (host, math.max(0, (expectedCount - existedCount).ceil.toInt))
+ }
+ }
+
+ /**
+ * According to the locality ratio and number of container requests, calculate the host to
+ * possible number of containers for pending allocated containers.
+ *
+ * If current locality ratio of hosts is: Host1 : Host2 : Host3 = 20 : 20 : 10,
+ * and pending container requests is 3, so the possible number of containers on
+ * Host1 : Host2 : Host3 will be 1.2 : 1.2 : 0.6.
+ * @param localityMatchedPendingAllocations A sequence of pending container request which
+ * matches the localities of current required tasks.
+ * @return a Map with hostname as key and possible number of containers on this host as value
+ */
+ private def pendingHostToContainerCount(
+ localityMatchedPendingAllocations: Seq[ContainerRequest]): Map[String, Double] = {
+ val pendingHostToContainerCount = new HashMap[String, Int]()
+ localityMatchedPendingAllocations.foreach { cr =>
+ cr.getNodes.asScala.foreach { n =>
+ val count = pendingHostToContainerCount.getOrElse(n, 0) + 1
+ pendingHostToContainerCount(n) = count
+ }
+ }
+
+ val possibleTotalContainerNum = pendingHostToContainerCount.values.sum
+ val localityMatchedPendingNum = localityMatchedPendingAllocations.size.toDouble
+ pendingHostToContainerCount.mapValues(_ * localityMatchedPendingNum / possibleTotalContainerNum)
+ .toMap
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
new file mode 100644
index 0000000000..0b66d1cf08
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -0,0 +1,727 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.util.Collections
+import java.util.concurrent._
+import java.util.regex.Pattern
+
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.RackResolver
+import org.apache.log4j.{Level, Logger}
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkException}
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef}
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason}
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId
+import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
+
+/**
+ * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding
+ * what to do with containers when YARN fulfills these requests.
+ *
+ * This class makes use of YARN's AMRMClient APIs. We interact with the AMRMClient in three ways:
+ * * Making our resource needs known, which updates local bookkeeping about containers requested.
+ * * Calling "allocate", which syncs our local container requests with the RM, and returns any
+ * containers that YARN has granted to us. This also functions as a heartbeat.
+ * * Processing the containers granted to us to possibly launch executors inside of them.
+ *
+ * The public methods of this class are thread-safe. All methods that mutate state are
+ * synchronized.
+ */
+private[yarn] class YarnAllocator(
+ driverUrl: String,
+ driverRef: RpcEndpointRef,
+ conf: YarnConfiguration,
+ sparkConf: SparkConf,
+ amClient: AMRMClient[ContainerRequest],
+ appAttemptId: ApplicationAttemptId,
+ securityMgr: SecurityManager,
+ localResources: Map[String, LocalResource])
+ extends Logging {
+
+ import YarnAllocator._
+
+ // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
+ if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
+ Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
+ }
+
+ // Visible for testing.
+ val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]
+ val allocatedContainerToHostMap = new HashMap[ContainerId, String]
+
+ // Containers that we no longer care about. We've either already told the RM to release them or
+ // will on the next heartbeat. Containers get removed from this map after the RM tells us they've
+ // completed.
+ private val releasedContainers = Collections.newSetFromMap[ContainerId](
+ new ConcurrentHashMap[ContainerId, java.lang.Boolean])
+
+ @volatile private var numExecutorsRunning = 0
+
+ /**
+ * Used to generate a unique ID per executor
+ *
+ * Init `executorIdCounter`. when AM restart, `executorIdCounter` will reset to 0. Then
+ * the id of new executor will start from 1, this will conflict with the executor has
+ * already created before. So, we should initialize the `executorIdCounter` by getting
+ * the max executorId from driver.
+ *
+ * And this situation of executorId conflict is just in yarn client mode, so this is an issue
+ * in yarn client mode. For more details, can check in jira.
+ *
+ * @see SPARK-12864
+ */
+ private var executorIdCounter: Int =
+ driverRef.askWithRetry[Int](RetrieveLastAllocatedExecutorId)
+
+ // Queue to store the timestamp of failed executors
+ private val failedExecutorsTimeStamps = new Queue[Long]()
+
+ private var clock: Clock = new SystemClock
+
+ private val executorFailuresValidityInterval =
+ sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L)
+
+ @volatile private var targetNumExecutors =
+ YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf)
+
+ // Executor loss reason requests that are pending - maps from executor ID for inquiry to a
+ // list of requesters that should be responded to once we find out why the given executor
+ // was lost.
+ private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]]
+
+ // Maintain loss reasons for already released executors, it will be added when executor loss
+ // reason is got from AM-RM call, and be removed after querying this loss reason.
+ private val releasedExecutorLossReasons = new HashMap[String, ExecutorLossReason]
+
+ // Keep track of which container is running which executor to remove the executors later
+ // Visible for testing.
+ private[yarn] val executorIdToContainer = new HashMap[String, Container]
+
+ private var numUnexpectedContainerRelease = 0L
+ private val containerIdToExecutorId = new HashMap[ContainerId, String]
+
+ // Executor memory in MB.
+ protected val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt
+ // Additional memory overhead.
+ protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
+ math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt
+ // Number of cores per executor.
+ protected val executorCores = sparkConf.get(EXECUTOR_CORES)
+ // Resource capability requested for each executors
+ private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores)
+
+ private val launcherPool = ThreadUtils.newDaemonCachedThreadPool(
+ "ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS))
+
+ // For testing
+ private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true)
+
+ private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION)
+
+ // ContainerRequest constructor that can take a node label expression. We grab it through
+ // reflection because it's only available in later versions of YARN.
+ private val nodeLabelConstructor = labelExpression.flatMap { expr =>
+ try {
+ Some(classOf[ContainerRequest].getConstructor(classOf[Resource],
+ classOf[Array[String]], classOf[Array[String]], classOf[Priority], classOf[Boolean],
+ classOf[String]))
+ } catch {
+ case e: NoSuchMethodException =>
+ logWarning(s"Node label expression $expr will be ignored because YARN version on" +
+ " classpath does not support it.")
+ None
+ }
+ }
+
+ // A map to store preferred hostname and possible task numbers running on it.
+ private var hostToLocalTaskCounts: Map[String, Int] = Map.empty
+
+ // Number of tasks that have locality preferences in active stages
+ private var numLocalityAwareTasks: Int = 0
+
+ // A container placement strategy based on pending tasks' locality preference
+ private[yarn] val containerPlacementStrategy =
+ new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource)
+
+ /**
+ * Use a different clock for YarnAllocator. This is mainly used for testing.
+ */
+ def setClock(newClock: Clock): Unit = {
+ clock = newClock
+ }
+
+ def getNumExecutorsRunning: Int = numExecutorsRunning
+
+ def getNumExecutorsFailed: Int = synchronized {
+ val endTime = clock.getTimeMillis()
+
+ while (executorFailuresValidityInterval > 0
+ && failedExecutorsTimeStamps.nonEmpty
+ && failedExecutorsTimeStamps.head < endTime - executorFailuresValidityInterval) {
+ failedExecutorsTimeStamps.dequeue()
+ }
+
+ failedExecutorsTimeStamps.size
+ }
+
+ /**
+ * A sequence of pending container requests that have not yet been fulfilled.
+ */
+ def getPendingAllocate: Seq[ContainerRequest] = getPendingAtLocation(ANY_HOST)
+
+ /**
+ * A sequence of pending container requests at the given location that have not yet been
+ * fulfilled.
+ */
+ private def getPendingAtLocation(location: String): Seq[ContainerRequest] = {
+ amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala
+ .flatMap(_.asScala)
+ .toSeq
+ }
+
+ /**
+ * Request as many executors from the ResourceManager as needed to reach the desired total. If
+ * the requested total is smaller than the current number of running executors, no executors will
+ * be killed.
+ * @param requestedTotal total number of containers requested
+ * @param localityAwareTasks number of locality aware tasks to be used as container placement hint
+ * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as
+ * container placement hint.
+ * @return Whether the new requested total is different than the old value.
+ */
+ def requestTotalExecutorsWithPreferredLocalities(
+ requestedTotal: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int]): Boolean = synchronized {
+ this.numLocalityAwareTasks = localityAwareTasks
+ this.hostToLocalTaskCounts = hostToLocalTaskCount
+
+ if (requestedTotal != targetNumExecutors) {
+ logInfo(s"Driver requested a total number of $requestedTotal executor(s).")
+ targetNumExecutors = requestedTotal
+ true
+ } else {
+ false
+ }
+ }
+
+ /**
+ * Request that the ResourceManager release the container running the specified executor.
+ */
+ def killExecutor(executorId: String): Unit = synchronized {
+ if (executorIdToContainer.contains(executorId)) {
+ val container = executorIdToContainer.get(executorId).get
+ internalReleaseContainer(container)
+ numExecutorsRunning -= 1
+ } else {
+ logWarning(s"Attempted to kill unknown executor $executorId!")
+ }
+ }
+
+ /**
+ * Request resources such that, if YARN gives us all we ask for, we'll have a number of containers
+ * equal to maxExecutors.
+ *
+ * Deal with any containers YARN has granted to us by possibly launching executors in them.
+ *
+ * This must be synchronized because variables read in this method are mutated by other methods.
+ */
+ def allocateResources(): Unit = synchronized {
+ updateResourceRequests()
+
+ val progressIndicator = 0.1f
+ // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container
+ // requests.
+ val allocateResponse = amClient.allocate(progressIndicator)
+
+ val allocatedContainers = allocateResponse.getAllocatedContainers()
+
+ if (allocatedContainers.size > 0) {
+ logDebug("Allocated containers: %d. Current executor count: %d. Cluster resources: %s."
+ .format(
+ allocatedContainers.size,
+ numExecutorsRunning,
+ allocateResponse.getAvailableResources))
+
+ handleAllocatedContainers(allocatedContainers.asScala)
+ }
+
+ val completedContainers = allocateResponse.getCompletedContainersStatuses()
+ if (completedContainers.size > 0) {
+ logDebug("Completed %d containers".format(completedContainers.size))
+ processCompletedContainers(completedContainers.asScala)
+ logDebug("Finished processing %d completed containers. Current running executor count: %d."
+ .format(completedContainers.size, numExecutorsRunning))
+ }
+ }
+
+ /**
+ * Update the set of container requests that we will sync with the RM based on the number of
+ * executors we have currently running and our target number of executors.
+ *
+ * Visible for testing.
+ */
+ def updateResourceRequests(): Unit = {
+ val pendingAllocate = getPendingAllocate
+ val numPendingAllocate = pendingAllocate.size
+ val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning
+
+ if (missing > 0) {
+ logInfo(s"Will request $missing executor container(s), each with " +
+ s"${resource.getVirtualCores} core(s) and " +
+ s"${resource.getMemory} MB memory (including $memoryOverhead MB of overhead)")
+
+ // Split the pending container request into three groups: locality matched list, locality
+ // unmatched list and non-locality list. Take the locality matched container request into
+ // consideration of container placement, treat as allocated containers.
+ // For locality unmatched and locality free container requests, cancel these container
+ // requests, since required locality preference has been changed, recalculating using
+ // container placement strategy.
+ val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality(
+ hostToLocalTaskCounts, pendingAllocate)
+
+ // cancel "stale" requests for locations that are no longer needed
+ staleRequests.foreach { stale =>
+ amClient.removeContainerRequest(stale)
+ }
+ val cancelledContainers = staleRequests.size
+ if (cancelledContainers > 0) {
+ logInfo(s"Canceled $cancelledContainers container request(s) (locality no longer needed)")
+ }
+
+ // consider the number of new containers and cancelled stale containers available
+ val availableContainers = missing + cancelledContainers
+
+ // to maximize locality, include requests with no locality preference that can be cancelled
+ val potentialContainers = availableContainers + anyHostRequests.size
+
+ val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers(
+ potentialContainers, numLocalityAwareTasks, hostToLocalTaskCounts,
+ allocatedHostToContainersMap, localRequests)
+
+ val newLocalityRequests = new mutable.ArrayBuffer[ContainerRequest]
+ containerLocalityPreferences.foreach {
+ case ContainerLocalityPreferences(nodes, racks) if nodes != null =>
+ newLocalityRequests += createContainerRequest(resource, nodes, racks)
+ case _ =>
+ }
+
+ if (availableContainers >= newLocalityRequests.size) {
+ // more containers are available than needed for locality, fill in requests for any host
+ for (i <- 0 until (availableContainers - newLocalityRequests.size)) {
+ newLocalityRequests += createContainerRequest(resource, null, null)
+ }
+ } else {
+ val numToCancel = newLocalityRequests.size - availableContainers
+ // cancel some requests without locality preferences to schedule more local containers
+ anyHostRequests.slice(0, numToCancel).foreach { nonLocal =>
+ amClient.removeContainerRequest(nonLocal)
+ }
+ if (numToCancel > 0) {
+ logInfo(s"Canceled $numToCancel unlocalized container requests to resubmit with locality")
+ }
+ }
+
+ newLocalityRequests.foreach { request =>
+ amClient.addContainerRequest(request)
+ }
+
+ if (log.isInfoEnabled()) {
+ val (localized, anyHost) = newLocalityRequests.partition(_.getNodes() != null)
+ if (anyHost.nonEmpty) {
+ logInfo(s"Submitted ${anyHost.size} unlocalized container requests.")
+ }
+ localized.foreach { request =>
+ logInfo(s"Submitted container request for host ${hostStr(request)}.")
+ }
+ }
+ } else if (numPendingAllocate > 0 && missing < 0) {
+ val numToCancel = math.min(numPendingAllocate, -missing)
+ logInfo(s"Canceling requests for $numToCancel executor container(s) to have a new desired " +
+ s"total $targetNumExecutors executors.")
+
+ val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource)
+ if (!matchingRequests.isEmpty) {
+ matchingRequests.iterator().next().asScala
+ .take(numToCancel).foreach(amClient.removeContainerRequest)
+ } else {
+ logWarning("Expected to find pending requests, but found none.")
+ }
+ }
+ }
+
+ private def hostStr(request: ContainerRequest): String = {
+ Option(request.getNodes) match {
+ case Some(nodes) => nodes.asScala.mkString(",")
+ case None => "Any"
+ }
+ }
+
+ /**
+ * Creates a container request, handling the reflection required to use YARN features that were
+ * added in recent versions.
+ */
+ private def createContainerRequest(
+ resource: Resource,
+ nodes: Array[String],
+ racks: Array[String]): ContainerRequest = {
+ nodeLabelConstructor.map { constructor =>
+ constructor.newInstance(resource, nodes, racks, RM_REQUEST_PRIORITY, true: java.lang.Boolean,
+ labelExpression.orNull)
+ }.getOrElse(new ContainerRequest(resource, nodes, racks, RM_REQUEST_PRIORITY))
+ }
+
+ /**
+ * Handle containers granted by the RM by launching executors on them.
+ *
+ * Due to the way the YARN allocation protocol works, certain healthy race conditions can result
+ * in YARN granting containers that we no longer need. In this case, we release them.
+ *
+ * Visible for testing.
+ */
+ def handleAllocatedContainers(allocatedContainers: Seq[Container]): Unit = {
+ val containersToUse = new ArrayBuffer[Container](allocatedContainers.size)
+
+ // Match incoming requests by host
+ val remainingAfterHostMatches = new ArrayBuffer[Container]
+ for (allocatedContainer <- allocatedContainers) {
+ matchContainerToRequest(allocatedContainer, allocatedContainer.getNodeId.getHost,
+ containersToUse, remainingAfterHostMatches)
+ }
+
+ // Match remaining by rack
+ val remainingAfterRackMatches = new ArrayBuffer[Container]
+ for (allocatedContainer <- remainingAfterHostMatches) {
+ val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation
+ matchContainerToRequest(allocatedContainer, rack, containersToUse,
+ remainingAfterRackMatches)
+ }
+
+ // Assign remaining that are neither node-local nor rack-local
+ val remainingAfterOffRackMatches = new ArrayBuffer[Container]
+ for (allocatedContainer <- remainingAfterRackMatches) {
+ matchContainerToRequest(allocatedContainer, ANY_HOST, containersToUse,
+ remainingAfterOffRackMatches)
+ }
+
+ if (!remainingAfterOffRackMatches.isEmpty) {
+ logDebug(s"Releasing ${remainingAfterOffRackMatches.size} unneeded containers that were " +
+ s"allocated to us")
+ for (container <- remainingAfterOffRackMatches) {
+ internalReleaseContainer(container)
+ }
+ }
+
+ runAllocatedContainers(containersToUse)
+
+ logInfo("Received %d containers from YARN, launching executors on %d of them."
+ .format(allocatedContainers.size, containersToUse.size))
+ }
+
+ /**
+ * Looks for requests for the given location that match the given container allocation. If it
+ * finds one, removes the request so that it won't be submitted again. Places the container into
+ * containersToUse or remaining.
+ *
+ * @param allocatedContainer container that was given to us by YARN
+ * @param location resource name, either a node, rack, or *
+ * @param containersToUse list of containers that will be used
+ * @param remaining list of containers that will not be used
+ */
+ private def matchContainerToRequest(
+ allocatedContainer: Container,
+ location: String,
+ containersToUse: ArrayBuffer[Container],
+ remaining: ArrayBuffer[Container]): Unit = {
+ // SPARK-6050: certain Yarn configurations return a virtual core count that doesn't match the
+ // request; for example, capacity scheduler + DefaultResourceCalculator. So match on requested
+ // memory, but use the asked vcore count for matching, effectively disabling matching on vcore
+ // count.
+ val matchingResource = Resource.newInstance(allocatedContainer.getResource.getMemory,
+ resource.getVirtualCores)
+ val matchingRequests = amClient.getMatchingRequests(allocatedContainer.getPriority, location,
+ matchingResource)
+
+ // Match the allocation to a request
+ if (!matchingRequests.isEmpty) {
+ val containerRequest = matchingRequests.get(0).iterator.next
+ amClient.removeContainerRequest(containerRequest)
+ containersToUse += allocatedContainer
+ } else {
+ remaining += allocatedContainer
+ }
+ }
+
+ /**
+ * Launches executors in the allocated containers.
+ */
+ private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = {
+ for (container <- containersToUse) {
+ executorIdCounter += 1
+ val executorHostname = container.getNodeId.getHost
+ val containerId = container.getId
+ val executorId = executorIdCounter.toString
+ assert(container.getResource.getMemory >= resource.getMemory)
+ logInfo(s"Launching container $containerId on host $executorHostname")
+
+ def updateInternalState(): Unit = synchronized {
+ numExecutorsRunning += 1
+ executorIdToContainer(executorId) = container
+ containerIdToExecutorId(container.getId) = executorId
+
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname,
+ new HashSet[ContainerId])
+ containerSet += containerId
+ allocatedContainerToHostMap.put(containerId, executorHostname)
+ }
+
+ if (numExecutorsRunning < targetNumExecutors) {
+ if (launchContainers) {
+ launcherPool.execute(new Runnable {
+ override def run(): Unit = {
+ try {
+ new ExecutorRunnable(
+ Some(container),
+ conf,
+ sparkConf,
+ driverUrl,
+ executorId,
+ executorHostname,
+ executorMemory,
+ executorCores,
+ appAttemptId.getApplicationId.toString,
+ securityMgr,
+ localResources
+ ).run()
+ updateInternalState()
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Failed to launch executor $executorId on container $containerId", e)
+ // Assigned container should be released immediately to avoid unnecessary resource
+ // occupation.
+ amClient.releaseAssignedContainer(containerId)
+ }
+ }
+ })
+ } else {
+ // For test only
+ updateInternalState()
+ }
+ } else {
+ logInfo(("Skip launching executorRunnable as runnning Excecutors count: %d " +
+ "reached target Executors count: %d.").format(numExecutorsRunning, targetNumExecutors))
+ }
+ }
+ }
+
+ // Visible for testing.
+ private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = {
+ for (completedContainer <- completedContainers) {
+ val containerId = completedContainer.getContainerId
+ val alreadyReleased = releasedContainers.remove(containerId)
+ val hostOpt = allocatedContainerToHostMap.get(containerId)
+ val onHostStr = hostOpt.map(host => s" on host: $host").getOrElse("")
+ val exitReason = if (!alreadyReleased) {
+ // Decrement the number of executors running. The next iteration of
+ // the ApplicationMaster's reporting thread will take care of allocating.
+ numExecutorsRunning -= 1
+ logInfo("Completed container %s%s (state: %s, exit status: %s)".format(
+ containerId,
+ onHostStr,
+ completedContainer.getState,
+ completedContainer.getExitStatus))
+ // Hadoop 2.2.X added a ContainerExitStatus we should switch to use
+ // there are some exit status' we shouldn't necessarily count against us, but for
+ // now I think its ok as none of the containers are expected to exit.
+ val exitStatus = completedContainer.getExitStatus
+ val (exitCausedByApp, containerExitReason) = exitStatus match {
+ case ContainerExitStatus.SUCCESS =>
+ (false, s"Executor for container $containerId exited because of a YARN event (e.g., " +
+ "pre-emption) and not because of an error in the running job.")
+ case ContainerExitStatus.PREEMPTED =>
+ // Preemption is not the fault of the running tasks, since YARN preempts containers
+ // merely to do resource sharing, and tasks that fail due to preempted executors could
+ // just as easily finish on any other executor. See SPARK-8167.
+ (false, s"Container ${containerId}${onHostStr} was preempted.")
+ // Should probably still count memory exceeded exit codes towards task failures
+ case VMEM_EXCEEDED_EXIT_CODE =>
+ (true, memLimitExceededLogMessage(
+ completedContainer.getDiagnostics,
+ VMEM_EXCEEDED_PATTERN))
+ case PMEM_EXCEEDED_EXIT_CODE =>
+ (true, memLimitExceededLogMessage(
+ completedContainer.getDiagnostics,
+ PMEM_EXCEEDED_PATTERN))
+ case _ =>
+ // Enqueue the timestamp of failed executor
+ failedExecutorsTimeStamps.enqueue(clock.getTimeMillis())
+ (true, "Container marked as failed: " + containerId + onHostStr +
+ ". Exit status: " + completedContainer.getExitStatus +
+ ". Diagnostics: " + completedContainer.getDiagnostics)
+
+ }
+ if (exitCausedByApp) {
+ logWarning(containerExitReason)
+ } else {
+ logInfo(containerExitReason)
+ }
+ ExecutorExited(exitStatus, exitCausedByApp, containerExitReason)
+ } else {
+ // If we have already released this container, then it must mean
+ // that the driver has explicitly requested it to be killed
+ ExecutorExited(completedContainer.getExitStatus, exitCausedByApp = false,
+ s"Container $containerId exited from explicit termination request.")
+ }
+
+ for {
+ host <- hostOpt
+ containerSet <- allocatedHostToContainersMap.get(host)
+ } {
+ containerSet.remove(containerId)
+ if (containerSet.isEmpty) {
+ allocatedHostToContainersMap.remove(host)
+ } else {
+ allocatedHostToContainersMap.update(host, containerSet)
+ }
+
+ allocatedContainerToHostMap.remove(containerId)
+ }
+
+ containerIdToExecutorId.remove(containerId).foreach { eid =>
+ executorIdToContainer.remove(eid)
+ pendingLossReasonRequests.remove(eid) match {
+ case Some(pendingRequests) =>
+ // Notify application of executor loss reasons so it can decide whether it should abort
+ pendingRequests.foreach(_.reply(exitReason))
+
+ case None =>
+ // We cannot find executor for pending reasons. This is because completed container
+ // is processed before querying pending result. We should store it for later query.
+ // This is usually happened when explicitly killing a container, the result will be
+ // returned in one AM-RM communication. So query RPC will be later than this completed
+ // container process.
+ releasedExecutorLossReasons.put(eid, exitReason)
+ }
+ if (!alreadyReleased) {
+ // The executor could have gone away (like no route to host, node failure, etc)
+ // Notify backend about the failure of the executor
+ numUnexpectedContainerRelease += 1
+ driverRef.send(RemoveExecutor(eid, exitReason))
+ }
+ }
+ }
+ }
+
+ /**
+ * Register that some RpcCallContext has asked the AM why the executor was lost. Note that
+ * we can only find the loss reason to send back in the next call to allocateResources().
+ */
+ private[yarn] def enqueueGetLossReasonRequest(
+ eid: String,
+ context: RpcCallContext): Unit = synchronized {
+ if (executorIdToContainer.contains(eid)) {
+ pendingLossReasonRequests
+ .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context
+ } else if (releasedExecutorLossReasons.contains(eid)) {
+ // Executor is already released explicitly before getting the loss reason, so directly send
+ // the pre-stored lost reason
+ context.reply(releasedExecutorLossReasons.remove(eid).get)
+ } else {
+ logWarning(s"Tried to get the loss reason for non-existent executor $eid")
+ context.sendFailure(
+ new SparkException(s"Fail to find loss reason for non-existent executor $eid"))
+ }
+ }
+
+ private def internalReleaseContainer(container: Container): Unit = {
+ releasedContainers.add(container.getId())
+ amClient.releaseAssignedContainer(container.getId())
+ }
+
+ private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease
+
+ private[yarn] def getNumPendingLossReasonRequests: Int = synchronized {
+ pendingLossReasonRequests.size
+ }
+
+ /**
+ * Split the pending container requests into 3 groups based on current localities of pending
+ * tasks.
+ * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as
+ * container placement hint.
+ * @param pendingAllocations A sequence of pending allocation container request.
+ * @return A tuple of 3 sequences, first is a sequence of locality matched container
+ * requests, second is a sequence of locality unmatched container requests, and third is a
+ * sequence of locality free container requests.
+ */
+ private def splitPendingAllocationsByLocality(
+ hostToLocalTaskCount: Map[String, Int],
+ pendingAllocations: Seq[ContainerRequest]
+ ): (Seq[ContainerRequest], Seq[ContainerRequest], Seq[ContainerRequest]) = {
+ val localityMatched = ArrayBuffer[ContainerRequest]()
+ val localityUnMatched = ArrayBuffer[ContainerRequest]()
+ val localityFree = ArrayBuffer[ContainerRequest]()
+
+ val preferredHosts = hostToLocalTaskCount.keySet
+ pendingAllocations.foreach { cr =>
+ val nodes = cr.getNodes
+ if (nodes == null) {
+ localityFree += cr
+ } else if (nodes.asScala.toSet.intersect(preferredHosts).nonEmpty) {
+ localityMatched += cr
+ } else {
+ localityUnMatched += cr
+ }
+ }
+
+ (localityMatched.toSeq, localityUnMatched.toSeq, localityFree.toSeq)
+ }
+
+}
+
+private object YarnAllocator {
+ val MEM_REGEX = "[0-9.]+ [KMG]B"
+ val PMEM_EXCEEDED_PATTERN =
+ Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used")
+ val VMEM_EXCEEDED_PATTERN =
+ Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used")
+ val VMEM_EXCEEDED_EXIT_CODE = -103
+ val PMEM_EXCEEDED_EXIT_CODE = -104
+
+ def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = {
+ val matcher = pattern.matcher(diagnostics)
+ val diag = if (matcher.find()) " " + matcher.group() + "." else ""
+ ("Container killed by YARN for exceeding memory limits." + diag
+ + " Consider boosting spark.yarn.executor.memoryOverhead.")
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
new file mode 100644
index 0000000000..53df11eb66
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.webapp.util.WebAppUtils
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.util.Utils
+
+/**
+ * Handles registering and unregistering the application with the YARN ResourceManager.
+ */
+private[spark] class YarnRMClient extends Logging {
+
+ private var amClient: AMRMClient[ContainerRequest] = _
+ private var uiHistoryAddress: String = _
+ private var registered: Boolean = false
+
+ /**
+ * Registers the application master with the RM.
+ *
+ * @param conf The Yarn configuration.
+ * @param sparkConf The Spark configuration.
+ * @param uiAddress Address of the SparkUI.
+ * @param uiHistoryAddress Address of the application on the History Server.
+ * @param securityMgr The security manager.
+ * @param localResources Map with information about files distributed via YARN's cache.
+ */
+ def register(
+ driverUrl: String,
+ driverRef: RpcEndpointRef,
+ conf: YarnConfiguration,
+ sparkConf: SparkConf,
+ uiAddress: String,
+ uiHistoryAddress: String,
+ securityMgr: SecurityManager,
+ localResources: Map[String, LocalResource]
+ ): YarnAllocator = {
+ amClient = AMRMClient.createAMRMClient()
+ amClient.init(conf)
+ amClient.start()
+ this.uiHistoryAddress = uiHistoryAddress
+
+ logInfo("Registering the ApplicationMaster")
+ synchronized {
+ amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
+ registered = true
+ }
+ new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr,
+ localResources)
+ }
+
+ /**
+ * Unregister the AM. Guaranteed to only be called once.
+ *
+ * @param status The final status of the AM.
+ * @param diagnostics Diagnostics message to include in the final status.
+ */
+ def unregister(status: FinalApplicationStatus, diagnostics: String = ""): Unit = synchronized {
+ if (registered) {
+ amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
+ }
+ }
+
+ /** Returns the attempt ID. */
+ def getAttemptId(): ApplicationAttemptId = {
+ YarnSparkHadoopUtil.get.getContainerId.getApplicationAttemptId()
+ }
+
+ /** Returns the configuration for the AmIpFilter to add to the Spark UI. */
+ def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String): Map[String, String] = {
+ // Figure out which scheme Yarn is using. Note the method seems to have been added after 2.2,
+ // so not all stable releases have it.
+ val prefix = Try(classOf[WebAppUtils].getMethod("getHttpSchemePrefix", classOf[Configuration])
+ .invoke(null, conf).asInstanceOf[String]).getOrElse("http://")
+
+ // If running a new enough Yarn, use the HA-aware API for retrieving the RM addresses.
+ try {
+ val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter",
+ classOf[Configuration])
+ val proxies = method.invoke(null, conf).asInstanceOf[JList[String]]
+ val hosts = proxies.asScala.map { proxy => proxy.split(":")(0) }
+ val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase }
+ Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(","))
+ } catch {
+ case e: NoSuchMethodException =>
+ val proxy = WebAppUtils.getProxyHostAndPort(conf)
+ val parts = proxy.split(":")
+ val uriBase = prefix + proxy + proxyBase
+ Map("PROXY_HOST" -> parts(0), "PROXY_URI_BASE" -> uriBase)
+ }
+ }
+
+ /** Returns the maximum number of attempts to register the AM. */
+ def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = {
+ val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS).map(_.toInt)
+ val yarnMaxAttempts = yarnConf.getInt(
+ YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
+ val retval: Int = sparkMaxAttempts match {
+ case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
+ case None => yarnMaxAttempts
+ }
+
+ retval
+ }
+
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
new file mode 100644
index 0000000000..cc53b1b06e
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.regex.Matcher
+import java.util.regex.Pattern
+
+import scala.collection.mutable.{HashMap, ListBuffer}
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api.ApplicationConstants
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority}
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.ConverterUtils
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkException}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.security.{ConfigurableCredentialManager, CredentialUpdater}
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.YarnCommandBuilderUtils
+import org.apache.spark.util.Utils
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+class YarnSparkHadoopUtil extends SparkHadoopUtil {
+
+ private var credentialUpdater: CredentialUpdater = _
+
+ override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
+ dest.addCredentials(source.getCredentials())
+ }
+
+ // Note that all params which start with SPARK are propagated all the way through, so if in yarn
+ // mode, this MUST be set to true.
+ override def isYarnMode(): Boolean = { true }
+
+ // Return an appropriate (subclass) of Configuration. Creating a config initializes some Hadoop
+ // subsystems. Always create a new config, don't reuse yarnConf.
+ override def newConfiguration(conf: SparkConf): Configuration =
+ new YarnConfiguration(super.newConfiguration(conf))
+
+ // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
+ // cluster
+ override def addCredentials(conf: JobConf) {
+ val jobCreds = conf.getCredentials()
+ jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
+ }
+
+ override def getCurrentUserCredentials(): Credentials = {
+ UserGroupInformation.getCurrentUser().getCredentials()
+ }
+
+ override def addCurrentUserCredentials(creds: Credentials) {
+ UserGroupInformation.getCurrentUser().addCredentials(creds)
+ }
+
+ override def addSecretKeyToUserCredentials(key: String, secret: String) {
+ val creds = new Credentials()
+ creds.addSecretKey(new Text(key), secret.getBytes(UTF_8))
+ addCurrentUserCredentials(creds)
+ }
+
+ override def getSecretKeyFromUserCredentials(key: String): Array[Byte] = {
+ val credentials = getCurrentUserCredentials()
+ if (credentials != null) credentials.getSecretKey(new Text(key)) else null
+ }
+
+ private[spark] override def startCredentialUpdater(sparkConf: SparkConf): Unit = {
+ credentialUpdater =
+ new ConfigurableCredentialManager(sparkConf, newConfiguration(sparkConf)).credentialUpdater()
+ credentialUpdater.start()
+ }
+
+ private[spark] override def stopCredentialUpdater(): Unit = {
+ if (credentialUpdater != null) {
+ credentialUpdater.stop()
+ credentialUpdater = null
+ }
+ }
+
+ private[spark] def getContainerId: ContainerId = {
+ val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())
+ ConverterUtils.toContainerId(containerIdString)
+ }
+}
+
+object YarnSparkHadoopUtil {
+ // Additional memory overhead
+ // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering
+ // the common cases. Memory overhead tends to grow with container size.
+
+ val MEMORY_OVERHEAD_FACTOR = 0.10
+ val MEMORY_OVERHEAD_MIN = 384L
+
+ val ANY_HOST = "*"
+
+ val DEFAULT_NUMBER_EXECUTORS = 2
+
+ // All RM requests are issued with same priority : we do not (yet) have any distinction between
+ // request types (like map/reduce in hadoop for example)
+ val RM_REQUEST_PRIORITY = Priority.newInstance(1)
+
+ def get: YarnSparkHadoopUtil = {
+ val yarnMode = java.lang.Boolean.parseBoolean(
+ System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
+ if (!yarnMode) {
+ throw new SparkException("YarnSparkHadoopUtil is not available in non-YARN mode!")
+ }
+ SparkHadoopUtil.get.asInstanceOf[YarnSparkHadoopUtil]
+ }
+ /**
+ * Add a path variable to the given environment map.
+ * If the map already contains this key, append the value to the existing value instead.
+ */
+ def addPathToEnvironment(env: HashMap[String, String], key: String, value: String): Unit = {
+ val newValue = if (env.contains(key)) { env(key) + getClassPathSeparator + value } else value
+ env.put(key, newValue)
+ }
+
+ /**
+ * Set zero or more environment variables specified by the given input string.
+ * The input string is expected to take the form "KEY1=VAL1,KEY2=VAL2,KEY3=VAL3".
+ */
+ def setEnvFromInputString(env: HashMap[String, String], inputString: String): Unit = {
+ if (inputString != null && inputString.length() > 0) {
+ val childEnvs = inputString.split(",")
+ val p = Pattern.compile(environmentVariableRegex)
+ for (cEnv <- childEnvs) {
+ val parts = cEnv.split("=") // split on '='
+ val m = p.matcher(parts(1))
+ val sb = new StringBuffer
+ while (m.find()) {
+ val variable = m.group(1)
+ var replace = ""
+ if (env.get(variable) != None) {
+ replace = env.get(variable).get
+ } else {
+ // if this key is not configured for the child .. get it from the env
+ replace = System.getenv(variable)
+ if (replace == null) {
+ // the env key is note present anywhere .. simply set it
+ replace = ""
+ }
+ }
+ m.appendReplacement(sb, Matcher.quoteReplacement(replace))
+ }
+ m.appendTail(sb)
+ // This treats the environment variable as path variable delimited by `File.pathSeparator`
+ // This is kept for backward compatibility and consistency with Hadoop's behavior
+ addPathToEnvironment(env, parts(0), sb.toString)
+ }
+ }
+ }
+
+ private val environmentVariableRegex: String = {
+ if (Utils.isWindows) {
+ "%([A-Za-z_][A-Za-z0-9_]*?)%"
+ } else {
+ "\\$([A-Za-z_][A-Za-z0-9_]*)"
+ }
+ }
+
+ /**
+ * Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
+ * Not killing the task leaves various aspects of the executor and (to some extent) the jvm in
+ * an inconsistent state.
+ * TODO: If the OOM is not recoverable by rescheduling it on different node, then do
+ * 'something' to fail job ... akin to blacklisting trackers in mapred ?
+ *
+ * The handler if an OOM Exception is thrown by the JVM must be configured on Windows
+ * differently: the 'taskkill' command should be used, whereas Unix-based systems use 'kill'.
+ *
+ * As the JVM interprets both %p and %%p as the same, we can use either of them. However,
+ * some tests on Windows computers suggest, that the JVM only accepts '%%p'.
+ *
+ * Furthermore, the behavior of the character '%' on the Windows command line differs from
+ * the behavior of '%' in a .cmd file: it gets interpreted as an incomplete environment
+ * variable. Windows .cmd files escape a '%' by '%%'. Thus, the correct way of writing
+ * '%%p' in an escaped way is '%%%%p'.
+ */
+ private[yarn] def addOutOfMemoryErrorArgument(javaOpts: ListBuffer[String]): Unit = {
+ if (!javaOpts.exists(_.contains("-XX:OnOutOfMemoryError"))) {
+ if (Utils.isWindows) {
+ javaOpts += escapeForShell("-XX:OnOutOfMemoryError=taskkill /F /PID %%%%p")
+ } else {
+ javaOpts += "-XX:OnOutOfMemoryError='kill %p'"
+ }
+ }
+ }
+
+ /**
+ * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands
+ * using either
+ *
+ * (Unix-based) `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work.
+ * The argument is enclosed in single quotes and some key characters are escaped.
+ *
+ * (Windows-based) part of a .cmd file in which case windows escaping for each argument must be
+ * applied. Windows is quite lenient, however it is usually Java that causes trouble, needing to
+ * distinguish between arguments starting with '-' and class names. If arguments are surrounded
+ * by ' java takes the following string as is, hence an argument is mistakenly taken as a class
+ * name which happens to start with a '-'. The way to avoid this, is to surround nothing with
+ * a ', but instead with a ".
+ *
+ * @param arg A single argument.
+ * @return Argument quoted for execution via Yarn's generated shell script.
+ */
+ def escapeForShell(arg: String): String = {
+ if (arg != null) {
+ if (Utils.isWindows) {
+ YarnCommandBuilderUtils.quoteForBatchScript(arg)
+ } else {
+ val escaped = new StringBuilder("'")
+ for (i <- 0 to arg.length() - 1) {
+ arg.charAt(i) match {
+ case '$' => escaped.append("\\$")
+ case '"' => escaped.append("\\\"")
+ case '\'' => escaped.append("'\\''")
+ case c => escaped.append(c)
+ }
+ }
+ escaped.append("'").toString()
+ }
+ } else {
+ arg
+ }
+ }
+
+ // YARN/Hadoop acls are specified as user1,user2 group1,group2
+ // Users and groups are separated by a space and hence we need to pass the acls in same format
+ def getApplicationAclsForYarn(securityMgr: SecurityManager)
+ : Map[ApplicationAccessType, String] = {
+ Map[ApplicationAccessType, String] (
+ ApplicationAccessType.VIEW_APP -> (securityMgr.getViewAcls + " " +
+ securityMgr.getViewAclsGroups),
+ ApplicationAccessType.MODIFY_APP -> (securityMgr.getModifyAcls + " " +
+ securityMgr.getModifyAclsGroups)
+ )
+ }
+
+ /**
+ * Expand environment variable using Yarn API.
+ * If environment.$$() is implemented, return the result of it.
+ * Otherwise, return the result of environment.$()
+ * Note: $$() is added in Hadoop 2.4.
+ */
+ private lazy val expandMethod =
+ Try(classOf[Environment].getMethod("$$"))
+ .getOrElse(classOf[Environment].getMethod("$"))
+
+ def expandEnvironment(environment: Environment): String =
+ expandMethod.invoke(environment).asInstanceOf[String]
+
+ /**
+ * Get class path separator using Yarn API.
+ * If ApplicationConstants.CLASS_PATH_SEPARATOR is implemented, return it.
+ * Otherwise, return File.pathSeparator
+ * Note: CLASS_PATH_SEPARATOR is added in Hadoop 2.4.
+ */
+ private lazy val classPathSeparatorField =
+ Try(classOf[ApplicationConstants].getField("CLASS_PATH_SEPARATOR"))
+ .getOrElse(classOf[File].getField("pathSeparator"))
+
+ def getClassPathSeparator(): String = {
+ classPathSeparatorField.get(null).asInstanceOf[String]
+ }
+
+ /**
+ * Getting the initial target number of executors depends on whether dynamic allocation is
+ * enabled.
+ * If not using dynamic allocation it gets the number of executors requested by the user.
+ */
+ def getInitialTargetExecutorNumber(
+ conf: SparkConf,
+ numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = {
+ if (Utils.isDynamicAllocationEnabled(conf)) {
+ val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS)
+ val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf)
+ val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS)
+ require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors,
+ s"initial executor number $initialNumExecutors must between min executor number " +
+ s"$minNumExecutors and max executor number $maxNumExecutors")
+
+ initialNumExecutors
+ } else {
+ val targetNumExecutors =
+ sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(numExecutors)
+ // System property can override environment variable.
+ conf.get(EXECUTOR_INSTANCES).getOrElse(targetNumExecutors)
+ }
+ }
+}
+
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
new file mode 100644
index 0000000000..666cb456a9
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.util.concurrent.TimeUnit
+
+import org.apache.spark.internal.config.ConfigBuilder
+import org.apache.spark.network.util.ByteUnit
+
+package object config {
+
+ /* Common app configuration. */
+
+ private[spark] val APPLICATION_TAGS = ConfigBuilder("spark.yarn.tags")
+ .doc("Comma-separated list of strings to pass through as YARN application tags appearing " +
+ "in YARN Application Reports, which can be used for filtering when querying YARN.")
+ .stringConf
+ .toSequence
+ .createOptional
+
+ private[spark] val AM_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS =
+ ConfigBuilder("spark.yarn.am.attemptFailuresValidityInterval")
+ .doc("Interval after which AM failures will be considered independent and " +
+ "not accumulate towards the attempt count.")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createOptional
+
+ private[spark] val AM_PORT =
+ ConfigBuilder("spark.yarn.am.port")
+ .intConf
+ .createWithDefault(0)
+
+ private[spark] val EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS =
+ ConfigBuilder("spark.yarn.executor.failuresValidityInterval")
+ .doc("Interval after which Executor failures will be considered independent and not " +
+ "accumulate towards the attempt count.")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createOptional
+
+ private[spark] val MAX_APP_ATTEMPTS = ConfigBuilder("spark.yarn.maxAppAttempts")
+ .doc("Maximum number of AM attempts before failing the app.")
+ .intConf
+ .createOptional
+
+ private[spark] val USER_CLASS_PATH_FIRST = ConfigBuilder("spark.yarn.user.classpath.first")
+ .doc("Whether to place user jars in front of Spark's classpath.")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val GATEWAY_ROOT_PATH = ConfigBuilder("spark.yarn.config.gatewayPath")
+ .doc("Root of configuration paths that is present on gateway nodes, and will be replaced " +
+ "with the corresponding path in cluster machines.")
+ .stringConf
+ .createWithDefault(null)
+
+ private[spark] val REPLACEMENT_ROOT_PATH = ConfigBuilder("spark.yarn.config.replacementPath")
+ .doc(s"Path to use as a replacement for ${GATEWAY_ROOT_PATH.key} when launching processes " +
+ "in the YARN cluster.")
+ .stringConf
+ .createWithDefault(null)
+
+ private[spark] val QUEUE_NAME = ConfigBuilder("spark.yarn.queue")
+ .stringConf
+ .createWithDefault("default")
+
+ private[spark] val HISTORY_SERVER_ADDRESS = ConfigBuilder("spark.yarn.historyServer.address")
+ .stringConf
+ .createOptional
+
+ /* File distribution. */
+
+ private[spark] val SPARK_ARCHIVE = ConfigBuilder("spark.yarn.archive")
+ .doc("Location of archive containing jars files with Spark classes.")
+ .stringConf
+ .createOptional
+
+ private[spark] val SPARK_JARS = ConfigBuilder("spark.yarn.jars")
+ .doc("Location of jars containing Spark classes.")
+ .stringConf
+ .toSequence
+ .createOptional
+
+ private[spark] val ARCHIVES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.archives")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val FILES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.files")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val JARS_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.jars")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val PRESERVE_STAGING_FILES = ConfigBuilder("spark.yarn.preserve.staging.files")
+ .doc("Whether to preserve temporary files created by the job in HDFS.")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val STAGING_FILE_REPLICATION = ConfigBuilder("spark.yarn.submit.file.replication")
+ .doc("Replication factor for files uploaded by Spark to HDFS.")
+ .intConf
+ .createOptional
+
+ private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir")
+ .doc("Staging directory used while submitting applications.")
+ .stringConf
+ .createOptional
+
+ /* Cluster-mode launcher configuration. */
+
+ private[spark] val WAIT_FOR_APP_COMPLETION = ConfigBuilder("spark.yarn.submit.waitAppCompletion")
+ .doc("In cluster mode, whether to wait for the application to finish before exiting the " +
+ "launcher process.")
+ .booleanConf
+ .createWithDefault(true)
+
+ private[spark] val REPORT_INTERVAL = ConfigBuilder("spark.yarn.report.interval")
+ .doc("Interval between reports of the current app status in cluster mode.")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefaultString("1s")
+
+ /* Shared Client-mode AM / Driver configuration. */
+
+ private[spark] val AM_MAX_WAIT_TIME = ConfigBuilder("spark.yarn.am.waitTime")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefaultString("100s")
+
+ private[spark] val AM_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.am.nodeLabelExpression")
+ .doc("Node label expression for the AM.")
+ .stringConf
+ .createOptional
+
+ private[spark] val CONTAINER_LAUNCH_MAX_THREADS =
+ ConfigBuilder("spark.yarn.containerLauncherMaxThreads")
+ .intConf
+ .createWithDefault(25)
+
+ private[spark] val MAX_EXECUTOR_FAILURES = ConfigBuilder("spark.yarn.max.executor.failures")
+ .intConf
+ .createOptional
+
+ private[spark] val MAX_REPORTER_THREAD_FAILURES =
+ ConfigBuilder("spark.yarn.scheduler.reporterThread.maxFailures")
+ .intConf
+ .createWithDefault(5)
+
+ private[spark] val RM_HEARTBEAT_INTERVAL =
+ ConfigBuilder("spark.yarn.scheduler.heartbeat.interval-ms")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefaultString("3s")
+
+ private[spark] val INITIAL_HEARTBEAT_INTERVAL =
+ ConfigBuilder("spark.yarn.scheduler.initial-allocation.interval")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefaultString("200ms")
+
+ private[spark] val SCHEDULER_SERVICES = ConfigBuilder("spark.yarn.services")
+ .doc("A comma-separated list of class names of services to add to the scheduler.")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ /* Client-mode AM configuration. */
+
+ private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores")
+ .intConf
+ .createWithDefault(1)
+
+ private[spark] val AM_JAVA_OPTIONS = ConfigBuilder("spark.yarn.am.extraJavaOptions")
+ .doc("Extra Java options for the client-mode AM.")
+ .stringConf
+ .createOptional
+
+ private[spark] val AM_LIBRARY_PATH = ConfigBuilder("spark.yarn.am.extraLibraryPath")
+ .doc("Extra native library path for the client-mode AM.")
+ .stringConf
+ .createOptional
+
+ private[spark] val AM_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.am.memoryOverhead")
+ .bytesConf(ByteUnit.MiB)
+ .createOptional
+
+ private[spark] val AM_MEMORY = ConfigBuilder("spark.yarn.am.memory")
+ .bytesConf(ByteUnit.MiB)
+ .createWithDefaultString("512m")
+
+ /* Driver configuration. */
+
+ private[spark] val DRIVER_CORES = ConfigBuilder("spark.driver.cores")
+ .intConf
+ .createWithDefault(1)
+
+ private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.driver.memoryOverhead")
+ .bytesConf(ByteUnit.MiB)
+ .createOptional
+
+ /* Executor configuration. */
+
+ private[spark] val EXECUTOR_CORES = ConfigBuilder("spark.executor.cores")
+ .intConf
+ .createWithDefault(1)
+
+ private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.executor.memoryOverhead")
+ .bytesConf(ByteUnit.MiB)
+ .createOptional
+
+ private[spark] val EXECUTOR_NODE_LABEL_EXPRESSION =
+ ConfigBuilder("spark.yarn.executor.nodeLabelExpression")
+ .doc("Node label expression for executors.")
+ .stringConf
+ .createOptional
+
+ /* Security configuration. */
+
+ private[spark] val CREDENTIAL_FILE_MAX_COUNT =
+ ConfigBuilder("spark.yarn.credentials.file.retention.count")
+ .intConf
+ .createWithDefault(5)
+
+ private[spark] val CREDENTIALS_FILE_MAX_RETENTION =
+ ConfigBuilder("spark.yarn.credentials.file.retention.days")
+ .intConf
+ .createWithDefault(5)
+
+ private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes")
+ .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " +
+ "fs.defaultFS does not need to be listed here.")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ /* Rolled log aggregation configuration. */
+
+ private[spark] val ROLLED_LOG_INCLUDE_PATTERN =
+ ConfigBuilder("spark.yarn.rolledLog.includePattern")
+ .doc("Java Regex to filter the log files which match the defined include pattern and those " +
+ "log files will be aggregated in a rolling fashion.")
+ .stringConf
+ .createOptional
+
+ private[spark] val ROLLED_LOG_EXCLUDE_PATTERN =
+ ConfigBuilder("spark.yarn.rolledLog.excludePattern")
+ .doc("Java Regex to filter the log files which match the defined exclude pattern and those " +
+ "log files will not be aggregated in a rolling fashion.")
+ .stringConf
+ .createOptional
+
+ /* Private configs. */
+
+ private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file")
+ .internal()
+ .stringConf
+ .createWithDefault(null)
+
+ // Internal config to propagate the location of the user's jar to the driver/executors
+ private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar")
+ .internal()
+ .stringConf
+ .createOptional
+
+ // Internal config to propagate the locations of any extra jars to add to the classpath
+ // of the executors
+ private[spark] val SECONDARY_JARS = ConfigBuilder("spark.yarn.secondary.jars")
+ .internal()
+ .stringConf
+ .toSequence
+ .createOptional
+
+ /* Configuration and cached file propagation. */
+
+ private[spark] val CACHED_FILES = ConfigBuilder("spark.yarn.cache.filenames")
+ .internal()
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val CACHED_FILES_SIZES = ConfigBuilder("spark.yarn.cache.sizes")
+ .internal()
+ .longConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val CACHED_FILES_TIMESTAMPS = ConfigBuilder("spark.yarn.cache.timestamps")
+ .internal()
+ .longConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val CACHED_FILES_VISIBILITIES = ConfigBuilder("spark.yarn.cache.visibilities")
+ .internal()
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ // Either "file" or "archive", for each file.
+ private[spark] val CACHED_FILES_TYPES = ConfigBuilder("spark.yarn.cache.types")
+ .internal()
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ // The location of the conf archive in HDFS.
+ private[spark] val CACHED_CONF_ARCHIVE = ConfigBuilder("spark.yarn.cache.confArchive")
+ .internal()
+ .stringConf
+ .createOptional
+
+ private[spark] val CREDENTIALS_RENEWAL_TIME = ConfigBuilder("spark.yarn.credentials.renewalTime")
+ .internal()
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefault(Long.MaxValue)
+
+ private[spark] val CREDENTIALS_UPDATE_TIME = ConfigBuilder("spark.yarn.credentials.updateTime")
+ .internal()
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefault(Long.MaxValue)
+
+ // The list of cache-related config entries. This is used by Client and the AM to clean
+ // up the environment so that these settings do not appear on the web UI.
+ private[yarn] val CACHE_CONFIGS = Seq(
+ CACHED_FILES,
+ CACHED_FILES_SIZES,
+ CACHED_FILES_TIMESTAMPS,
+ CACHED_FILES_VISIBILITIES,
+ CACHED_FILES_TYPES,
+ CACHED_CONF_ARCHIVE)
+
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala
new file mode 100644
index 0000000000..7e76f402db
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.yarn.security
+
+import java.security.PrivilegedExceptionAction
+import java.util.concurrent.{Executors, TimeUnit}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.security.UserGroupInformation
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * The following methods are primarily meant to make sure long-running apps like Spark
+ * Streaming apps can run without interruption while accessing secured services. The
+ * scheduleLoginFromKeytab method is called on the AM to get the new credentials.
+ * This method wakes up a thread that logs into the KDC
+ * once 75% of the renewal interval of the original credentials used for the container
+ * has elapsed. It then obtains new credentials and writes them to HDFS in a
+ * pre-specified location - the prefix of which is specified in the sparkConf by
+ * spark.yarn.credentials.file (so the file(s) would be named c-timestamp1-1, c-timestamp2-2 etc.
+ * - each update goes to a new file, with a monotonically increasing suffix), also the
+ * timestamp1, timestamp2 here indicates the time of next update for CredentialUpdater.
+ * After this, the credentials are renewed once 75% of the new tokens renewal interval has elapsed.
+ *
+ * On the executor and driver (yarn client mode) side, the updateCredentialsIfRequired method is
+ * called once 80% of the validity of the original credentials has elapsed. At that time the
+ * executor finds the credentials file with the latest timestamp and checks if it has read those
+ * credentials before (by keeping track of the suffix of the last file it read). If a new file has
+ * appeared, it will read the credentials and update the currently running UGI with it. This
+ * process happens again once 80% of the validity of this has expired.
+ */
+private[yarn] class AMCredentialRenewer(
+ sparkConf: SparkConf,
+ hadoopConf: Configuration,
+ credentialManager: ConfigurableCredentialManager) extends Logging {
+
+ private var lastCredentialsFileSuffix = 0
+
+ private val credentialRenewer =
+ Executors.newSingleThreadScheduledExecutor(
+ ThreadUtils.namedThreadFactory("Credential Refresh Thread"))
+
+ private val hadoopUtil = YarnSparkHadoopUtil.get
+
+ private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH)
+ private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION)
+ private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT)
+ private val freshHadoopConf =
+ hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme)
+
+ @volatile private var timeOfNextRenewal = sparkConf.get(CREDENTIALS_RENEWAL_TIME)
+
+ /**
+ * Schedule a login from the keytab and principal set using the --principal and --keytab
+ * arguments to spark-submit. This login happens only when the credentials of the current user
+ * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from
+ * SparkConf to do the login. This method is a no-op in non-YARN mode.
+ *
+ */
+ private[spark] def scheduleLoginFromKeytab(): Unit = {
+ val principal = sparkConf.get(PRINCIPAL).get
+ val keytab = sparkConf.get(KEYTAB).get
+
+ /**
+ * Schedule re-login and creation of new credentials. If credentials have already expired, this
+ * method will synchronously create new ones.
+ */
+ def scheduleRenewal(runnable: Runnable): Unit = {
+ // Run now!
+ val remainingTime = timeOfNextRenewal - System.currentTimeMillis()
+ if (remainingTime <= 0) {
+ logInfo("Credentials have expired, creating new ones now.")
+ runnable.run()
+ } else {
+ logInfo(s"Scheduling login from keytab in $remainingTime millis.")
+ credentialRenewer.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS)
+ }
+ }
+
+ // This thread periodically runs on the AM to update the credentials on HDFS.
+ val credentialRenewerRunnable =
+ new Runnable {
+ override def run(): Unit = {
+ try {
+ writeNewCredentialsToHDFS(principal, keytab)
+ cleanupOldFiles()
+ } catch {
+ case e: Exception =>
+ // Log the error and try to write new tokens back in an hour
+ logWarning("Failed to write out new credentials to HDFS, will try again in an " +
+ "hour! If this happens too often tasks will fail.", e)
+ credentialRenewer.schedule(this, 1, TimeUnit.HOURS)
+ return
+ }
+ scheduleRenewal(this)
+ }
+ }
+ // Schedule update of credentials. This handles the case of updating the credentials right now
+ // as well, since the renewal interval will be 0, and the thread will get scheduled
+ // immediately.
+ scheduleRenewal(credentialRenewerRunnable)
+ }
+
+ // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At
+ // least numFilesToKeep files are kept for safety
+ private def cleanupOldFiles(): Unit = {
+ import scala.concurrent.duration._
+ try {
+ val remoteFs = FileSystem.get(freshHadoopConf)
+ val credentialsPath = new Path(credentialsFile)
+ val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles.days).toMillis
+ hadoopUtil.listFilesSorted(
+ remoteFs, credentialsPath.getParent,
+ credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
+ .dropRight(numFilesToKeep)
+ .takeWhile(_.getModificationTime < thresholdTime)
+ .foreach(x => remoteFs.delete(x.getPath, true))
+ } catch {
+ // Such errors are not fatal, so don't throw. Make sure they are logged though
+ case e: Exception =>
+ logWarning("Error while attempting to cleanup old credentials. If you are seeing many " +
+ "such warnings there may be an issue with your HDFS cluster.", e)
+ }
+ }
+
+ private def writeNewCredentialsToHDFS(principal: String, keytab: String): Unit = {
+ // Keytab is copied by YARN to the working directory of the AM, so full path is
+ // not needed.
+
+ // HACK:
+ // HDFS will not issue new delegation tokens, if the Credentials object
+ // passed in already has tokens for that FS even if the tokens are expired (it really only
+ // checks if there are tokens for the service, and not if they are valid). So the only real
+ // way to get new tokens is to make sure a different Credentials object is used each time to
+ // get new tokens and then the new tokens are copied over the current user's Credentials.
+ // So:
+ // - we login as a different user and get the UGI
+ // - use that UGI to get the tokens (see doAs block below)
+ // - copy the tokens over to the current user's credentials (this will overwrite the tokens
+ // in the current user's Credentials object for this FS).
+ // The login to KDC happens each time new tokens are required, but this is rare enough to not
+ // have to worry about (like once every day or so). This makes this code clearer than having
+ // to login and then relogin every time (the HDFS API may not relogin since we don't use this
+ // UGI directly for HDFS communication.
+ logInfo(s"Attempting to login to KDC using principal: $principal")
+ val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab)
+ logInfo("Successfully logged into KDC.")
+ val tempCreds = keytabLoggedInUGI.getCredentials
+ val credentialsPath = new Path(credentialsFile)
+ val dst = credentialsPath.getParent
+ var nearestNextRenewalTime = Long.MaxValue
+ keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] {
+ // Get a copy of the credentials
+ override def run(): Void = {
+ nearestNextRenewalTime = credentialManager.obtainCredentials(freshHadoopConf, tempCreds)
+ null
+ }
+ })
+
+ val currTime = System.currentTimeMillis()
+ val timeOfNextUpdate = if (nearestNextRenewalTime <= currTime) {
+ // If next renewal time is earlier than current time, we set next renewal time to current
+ // time, this will trigger next renewal immediately. Also set next update time to current
+ // time. There still has a gap between token renewal and update will potentially introduce
+ // issue.
+ logWarning(s"Next credential renewal time ($nearestNextRenewalTime) is earlier than " +
+ s"current time ($currTime), which is unexpected, please check your credential renewal " +
+ "related configurations in the target services.")
+ timeOfNextRenewal = currTime
+ currTime
+ } else {
+ // Next valid renewal time is about 75% of credential renewal time, and update time is
+ // slightly later than valid renewal time (80% of renewal time).
+ timeOfNextRenewal = ((nearestNextRenewalTime - currTime) * 0.75 + currTime).toLong
+ ((nearestNextRenewalTime - currTime) * 0.8 + currTime).toLong
+ }
+
+ // Add the temp credentials back to the original ones.
+ UserGroupInformation.getCurrentUser.addCredentials(tempCreds)
+ val remoteFs = FileSystem.get(freshHadoopConf)
+ // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM
+ // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file
+ // and update the lastCredentialsFileSuffix.
+ if (lastCredentialsFileSuffix == 0) {
+ hadoopUtil.listFilesSorted(
+ remoteFs, credentialsPath.getParent,
+ credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
+ .lastOption.foreach { status =>
+ lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath)
+ }
+ }
+ val nextSuffix = lastCredentialsFileSuffix + 1
+
+ val tokenPathStr =
+ credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM +
+ timeOfNextUpdate.toLong.toString + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM +
+ nextSuffix
+ val tokenPath = new Path(tokenPathStr)
+ val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
+
+ logInfo("Writing out delegation tokens to " + tempTokenPath.toString)
+ val credentials = UserGroupInformation.getCurrentUser.getCredentials
+ credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf)
+ logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr")
+ remoteFs.rename(tempTokenPath, tokenPath)
+ logInfo("Delegation token file rename complete.")
+ lastCredentialsFileSuffix = nextSuffix
+ }
+
+ def stop(): Unit = {
+ credentialRenewer.shutdown()
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala
new file mode 100644
index 0000000000..c4c07b4930
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn.security
+
+import java.util.ServiceLoader
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.security.Credentials
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * A ConfigurableCredentialManager to manage all the registered credential providers and offer
+ * APIs for other modules to obtain credentials as well as renewal time. By default
+ * [[HDFSCredentialProvider]], [[HiveCredentialProvider]] and [[HBaseCredentialProvider]] will
+ * be loaded in if not explicitly disabled, any plugged-in credential provider wants to be
+ * managed by ConfigurableCredentialManager needs to implement [[ServiceCredentialProvider]]
+ * interface and put into resources/META-INF/services to be loaded by ServiceLoader.
+ *
+ * Also each credential provider is controlled by
+ * spark.yarn.security.credentials.{service}.enabled, it will not be loaded in if set to false.
+ */
+private[yarn] final class ConfigurableCredentialManager(
+ sparkConf: SparkConf, hadoopConf: Configuration) extends Logging {
+ private val deprecatedProviderEnabledConfig = "spark.yarn.security.tokens.%s.enabled"
+ private val providerEnabledConfig = "spark.yarn.security.credentials.%s.enabled"
+
+ // Maintain all the registered credential providers
+ private val credentialProviders = {
+ val providers = ServiceLoader.load(classOf[ServiceCredentialProvider],
+ Utils.getContextOrSparkClassLoader).asScala
+
+ // Filter out credentials in which spark.yarn.security.credentials.{service}.enabled is false.
+ providers.filter { p =>
+ sparkConf.getOption(providerEnabledConfig.format(p.serviceName))
+ .orElse {
+ sparkConf.getOption(deprecatedProviderEnabledConfig.format(p.serviceName)).map { c =>
+ logWarning(s"${deprecatedProviderEnabledConfig.format(p.serviceName)} is deprecated, " +
+ s"using ${providerEnabledConfig.format(p.serviceName)} instead")
+ c
+ }
+ }.map(_.toBoolean).getOrElse(true)
+ }.map { p => (p.serviceName, p) }.toMap
+ }
+
+ /**
+ * Get credential provider for the specified service.
+ */
+ def getServiceCredentialProvider(service: String): Option[ServiceCredentialProvider] = {
+ credentialProviders.get(service)
+ }
+
+ /**
+ * Obtain credentials from all the registered providers.
+ * @return nearest time of next renewal, Long.MaxValue if all the credentials aren't renewable,
+ * otherwise the nearest renewal time of any credentials will be returned.
+ */
+ def obtainCredentials(hadoopConf: Configuration, creds: Credentials): Long = {
+ credentialProviders.values.flatMap { provider =>
+ if (provider.credentialsRequired(hadoopConf)) {
+ provider.obtainCredentials(hadoopConf, sparkConf, creds)
+ } else {
+ logDebug(s"Service ${provider.serviceName} does not require a token." +
+ s" Check your configuration to see if security is disabled or not.")
+ None
+ }
+ }.foldLeft(Long.MaxValue)(math.min)
+ }
+
+ /**
+ * Create an [[AMCredentialRenewer]] instance, caller should be responsible to stop this
+ * instance when it is not used. AM will use it to renew credentials periodically.
+ */
+ def credentialRenewer(): AMCredentialRenewer = {
+ new AMCredentialRenewer(sparkConf, hadoopConf, this)
+ }
+
+ /**
+ * Create an [[CredentialUpdater]] instance, caller should be resposible to stop this intance
+ * when it is not used. Executors and driver (client mode) will use it to update credentials.
+ * periodically.
+ */
+ def credentialUpdater(): CredentialUpdater = {
+ new CredentialUpdater(sparkConf, hadoopConf, this)
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala
new file mode 100644
index 0000000000..5df4fbd9c1
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn.security
+
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+private[spark] class CredentialUpdater(
+ sparkConf: SparkConf,
+ hadoopConf: Configuration,
+ credentialManager: ConfigurableCredentialManager) extends Logging {
+
+ @volatile private var lastCredentialsFileSuffix = 0
+
+ private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH)
+ private val freshHadoopConf =
+ SparkHadoopUtil.get.getConfBypassingFSCache(
+ hadoopConf, new Path(credentialsFile).toUri.getScheme)
+
+ private val credentialUpdater =
+ Executors.newSingleThreadScheduledExecutor(
+ ThreadUtils.namedThreadFactory("Credential Refresh Thread"))
+
+ // This thread wakes up and picks up new credentials from HDFS, if any.
+ private val credentialUpdaterRunnable =
+ new Runnable {
+ override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired())
+ }
+
+ /** Start the credential updater task */
+ def start(): Unit = {
+ val startTime = sparkConf.get(CREDENTIALS_RENEWAL_TIME)
+ val remainingTime = startTime - System.currentTimeMillis()
+ if (remainingTime <= 0) {
+ credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES)
+ } else {
+ logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime millis.")
+ credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS)
+ }
+ }
+
+ private def updateCredentialsIfRequired(): Unit = {
+ val timeToNextUpdate = try {
+ val credentialsFilePath = new Path(credentialsFile)
+ val remoteFs = FileSystem.get(freshHadoopConf)
+ SparkHadoopUtil.get.listFilesSorted(
+ remoteFs, credentialsFilePath.getParent,
+ credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
+ .lastOption.map { credentialsStatus =>
+ val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath)
+ if (suffix > lastCredentialsFileSuffix) {
+ logInfo("Reading new credentials from " + credentialsStatus.getPath)
+ val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath)
+ lastCredentialsFileSuffix = suffix
+ UserGroupInformation.getCurrentUser.addCredentials(newCredentials)
+ logInfo("Credentials updated from credentials file.")
+
+ val remainingTime = getTimeOfNextUpdateFromFileName(credentialsStatus.getPath)
+ - System.currentTimeMillis()
+ if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime
+ } else {
+ // If current credential file is older than expected, sleep 1 hour and check again.
+ TimeUnit.HOURS.toMillis(1)
+ }
+ }.getOrElse {
+ // Wait for 1 minute to check again if there's no credential file currently
+ TimeUnit.MINUTES.toMillis(1)
+ }
+ } catch {
+ // Since the file may get deleted while we are reading it, catch the Exception and come
+ // back in an hour to try again
+ case NonFatal(e) =>
+ logWarning("Error while trying to update credentials, will try again in 1 hour", e)
+ TimeUnit.HOURS.toMillis(1)
+ }
+
+ credentialUpdater.schedule(
+ credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS)
+ }
+
+ private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = {
+ val stream = remoteFs.open(tokenPath)
+ try {
+ val newCredentials = new Credentials()
+ newCredentials.readTokenStorageStream(stream)
+ newCredentials
+ } finally {
+ stream.close()
+ }
+ }
+
+ private def getTimeOfNextUpdateFromFileName(credentialsPath: Path): Long = {
+ val name = credentialsPath.getName
+ val index = name.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM)
+ val slice = name.substring(0, index)
+ val last2index = slice.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM)
+ name.substring(last2index + 1, index).toLong
+ }
+
+ def stop(): Unit = {
+ credentialUpdater.shutdown()
+ }
+
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala
new file mode 100644
index 0000000000..5571df09a2
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn.security
+
+import scala.reflect.runtime.universe
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.security.token.{Token, TokenIdentifier}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+
+private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging {
+
+ override def serviceName: String = "hbase"
+
+ override def obtainCredentials(
+ hadoopConf: Configuration,
+ sparkConf: SparkConf,
+ creds: Credentials): Option[Long] = {
+ try {
+ val mirror = universe.runtimeMirror(getClass.getClassLoader)
+ val obtainToken = mirror.classLoader.
+ loadClass("org.apache.hadoop.hbase.security.token.TokenUtil").
+ getMethod("obtainToken", classOf[Configuration])
+
+ logDebug("Attempting to fetch HBase security token.")
+ val token = obtainToken.invoke(null, hbaseConf(hadoopConf))
+ .asInstanceOf[Token[_ <: TokenIdentifier]]
+ logInfo(s"Get token from HBase: ${token.toString}")
+ creds.addToken(token.getService, token)
+ } catch {
+ case NonFatal(e) =>
+ logDebug(s"Failed to get token from service $serviceName", e)
+ }
+
+ None
+ }
+
+ override def credentialsRequired(hadoopConf: Configuration): Boolean = {
+ hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos"
+ }
+
+ private def hbaseConf(conf: Configuration): Configuration = {
+ try {
+ val mirror = universe.runtimeMirror(getClass.getClassLoader)
+ val confCreate = mirror.classLoader.
+ loadClass("org.apache.hadoop.hbase.HBaseConfiguration").
+ getMethod("create", classOf[Configuration])
+ confCreate.invoke(null, conf).asInstanceOf[Configuration]
+ } catch {
+ case NonFatal(e) =>
+ logDebug("Fail to invoke HBaseConfiguration", e)
+ conf
+ }
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala
new file mode 100644
index 0000000000..8d06d735ba
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn.security
+
+import java.io.{ByteArrayInputStream, DataInputStream}
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
+import org.apache.hadoop.mapred.Master
+import org.apache.hadoop.security.Credentials
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+
+private[security] class HDFSCredentialProvider extends ServiceCredentialProvider with Logging {
+ // Token renewal interval, this value will be set in the first call,
+ // if None means no token renewer specified, so cannot get token renewal interval.
+ private var tokenRenewalInterval: Option[Long] = null
+
+ override val serviceName: String = "hdfs"
+
+ override def obtainCredentials(
+ hadoopConf: Configuration,
+ sparkConf: SparkConf,
+ creds: Credentials): Option[Long] = {
+ // NameNode to access, used to get tokens from different FileSystems
+ nnsToAccess(hadoopConf, sparkConf).foreach { dst =>
+ val dstFs = dst.getFileSystem(hadoopConf)
+ logInfo("getting token for namenode: " + dst)
+ dstFs.addDelegationTokens(getTokenRenewer(hadoopConf), creds)
+ }
+
+ // Get the token renewal interval if it is not set. It will only be called once.
+ if (tokenRenewalInterval == null) {
+ tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf)
+ }
+
+ // Get the time of next renewal.
+ tokenRenewalInterval.map { interval =>
+ creds.getAllTokens.asScala
+ .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND)
+ .map { t =>
+ val identifier = new DelegationTokenIdentifier()
+ identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier)))
+ identifier.getIssueDate + interval
+ }.foldLeft(0L)(math.max)
+ }
+ }
+
+ private def getTokenRenewalInterval(
+ hadoopConf: Configuration, sparkConf: SparkConf): Option[Long] = {
+ // We cannot use the tokens generated with renewer yarn. Trying to renew
+ // those will fail with an access control issue. So create new tokens with the logged in
+ // user as renewer.
+ sparkConf.get(PRINCIPAL).map { renewer =>
+ val creds = new Credentials()
+ nnsToAccess(hadoopConf, sparkConf).foreach { dst =>
+ val dstFs = dst.getFileSystem(hadoopConf)
+ dstFs.addDelegationTokens(renewer, creds)
+ }
+ val t = creds.getAllTokens.asScala
+ .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND)
+ .head
+ val newExpiration = t.renew(hadoopConf)
+ val identifier = new DelegationTokenIdentifier()
+ identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier)))
+ val interval = newExpiration - identifier.getIssueDate
+ logInfo(s"Renewal Interval is $interval")
+ interval
+ }
+ }
+
+ private def getTokenRenewer(conf: Configuration): String = {
+ val delegTokenRenewer = Master.getMasterPrincipal(conf)
+ logDebug("delegation token renewer is: " + delegTokenRenewer)
+ if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
+ val errorMessage = "Can't get Master Kerberos principal for use as renewer"
+ logError(errorMessage)
+ throw new SparkException(errorMessage)
+ }
+
+ delegTokenRenewer
+ }
+
+ private def nnsToAccess(hadoopConf: Configuration, sparkConf: SparkConf): Set[Path] = {
+ sparkConf.get(NAMENODES_TO_ACCESS).map(new Path(_)).toSet +
+ sparkConf.get(STAGING_DIR).map(new Path(_))
+ .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory)
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala
new file mode 100644
index 0000000000..16d8fc32bb
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn.security
+
+import java.lang.reflect.UndeclaredThrowableException
+import java.security.PrivilegedExceptionAction
+
+import scala.reflect.runtime.universe
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import org.apache.hadoop.security.token.Token
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+private[security] class HiveCredentialProvider extends ServiceCredentialProvider with Logging {
+
+ override def serviceName: String = "hive"
+
+ private def hiveConf(hadoopConf: Configuration): Configuration = {
+ try {
+ val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader)
+ // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down
+ // to a Configuration and used without reflection
+ val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf")
+ // using the (Configuration, Class) constructor allows the current configuration to be
+ // included in the hive config.
+ val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration],
+ classOf[Object].getClass)
+ ctor.newInstance(hadoopConf, hiveConfClass).asInstanceOf[Configuration]
+ } catch {
+ case NonFatal(e) =>
+ logDebug("Fail to create Hive Configuration", e)
+ hadoopConf
+ }
+ }
+
+ override def credentialsRequired(hadoopConf: Configuration): Boolean = {
+ UserGroupInformation.isSecurityEnabled &&
+ hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty
+ }
+
+ override def obtainCredentials(
+ hadoopConf: Configuration,
+ sparkConf: SparkConf,
+ creds: Credentials): Option[Long] = {
+ val conf = hiveConf(hadoopConf)
+
+ val principalKey = "hive.metastore.kerberos.principal"
+ val principal = conf.getTrimmed(principalKey, "")
+ require(principal.nonEmpty, s"Hive principal $principalKey undefined")
+ val metastoreUri = conf.getTrimmed("hive.metastore.uris", "")
+ require(metastoreUri.nonEmpty, "Hive metastore uri undefined")
+
+ val currentUser = UserGroupInformation.getCurrentUser()
+ logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " +
+ s"$principal at $metastoreUri")
+
+ val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader)
+ val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive")
+ val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf")
+ val closeCurrent = hiveClass.getMethod("closeCurrent")
+
+ try {
+ // get all the instance methods before invoking any
+ val getDelegationToken = hiveClass.getMethod("getDelegationToken",
+ classOf[String], classOf[String])
+ val getHive = hiveClass.getMethod("get", hiveConfClass)
+
+ doAsRealUser {
+ val hive = getHive.invoke(null, conf)
+ val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal)
+ .asInstanceOf[String]
+ val hive2Token = new Token[DelegationTokenIdentifier]()
+ hive2Token.decodeFromUrlString(tokenStr)
+ logInfo(s"Get Token from hive metastore: ${hive2Token.toString}")
+ creds.addToken(new Text("hive.server2.delegation.token"), hive2Token)
+ }
+ } catch {
+ case NonFatal(e) =>
+ logDebug(s"Fail to get token from service $serviceName", e)
+ } finally {
+ Utils.tryLogNonFatalError {
+ closeCurrent.invoke(null)
+ }
+ }
+
+ None
+ }
+
+ /**
+ * Run some code as the real logged in user (which may differ from the current user, for
+ * example, when using proxying).
+ */
+ private def doAsRealUser[T](fn: => T): T = {
+ val currentUser = UserGroupInformation.getCurrentUser()
+ val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser)
+
+ // For some reason the Scala-generated anonymous class ends up causing an
+ // UndeclaredThrowableException, even if you annotate the method with @throws.
+ try {
+ realUser.doAs(new PrivilegedExceptionAction[T]() {
+ override def run(): T = fn
+ })
+ } catch {
+ case e: UndeclaredThrowableException => throw Option(e.getCause()).getOrElse(e)
+ }
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala
new file mode 100644
index 0000000000..4e3fcce8db
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn.security
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+
+import org.apache.spark.SparkConf
+
+/**
+ * A credential provider for a service. User must implement this if they need to access a
+ * secure service from Spark.
+ */
+trait ServiceCredentialProvider {
+
+ /**
+ * Name of the service to provide credentials. This name should unique, Spark internally will
+ * use this name to differentiate credential provider.
+ */
+ def serviceName: String
+
+ /**
+ * To decide whether credential is required for this service. By default it based on whether
+ * Hadoop security is enabled.
+ */
+ def credentialsRequired(hadoopConf: Configuration): Boolean = {
+ UserGroupInformation.isSecurityEnabled
+ }
+
+ /**
+ * Obtain credentials for this service and get the time of the next renewal.
+ * @param hadoopConf Configuration of current Hadoop Compatible system.
+ * @param sparkConf Spark configuration.
+ * @param creds Credentials to add tokens and security keys to.
+ * @return If this Credential is renewable and can be renewed, return the time of the next
+ * renewal, otherwise None should be returned.
+ */
+ def obtainCredentials(
+ hadoopConf: Configuration,
+ sparkConf: SparkConf,
+ creds: Credentials): Option[Long]
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala
new file mode 100644
index 0000000000..6c3556a2ee
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+import scala.util.Properties
+
+/**
+ * Exposes methods from the launcher library that are used by the YARN backend.
+ */
+private[spark] object YarnCommandBuilderUtils {
+
+ def quoteForBatchScript(arg: String): String = {
+ CommandBuilderUtils.quoteForBatchScript(arg)
+ }
+
+ def findJarsDir(sparkHome: String): String = {
+ val scalaVer = Properties.versionNumberString
+ .split("\\.")
+ .take(2)
+ .mkString(".")
+ CommandBuilderUtils.findJarsDir(sparkHome, scalaVer, true)
+ }
+
+ /**
+ * Adds the perm gen configuration to the list of java options if needed and not yet added.
+ *
+ * Note that this method adds the option based on the local JVM version; if the node where
+ * the container is running has a different Java version, there's a risk that the option will
+ * not be added (e.g. if the AM is running Java 8 but the container's node is set up to use
+ * Java 7).
+ */
+ def addPermGenSizeOpt(args: ListBuffer[String]): Unit = {
+ CommandBuilderUtils.addPermGenSizeOpt(args.asJava)
+ }
+
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala
new file mode 100644
index 0000000000..4ed285230f
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * An extension service that can be loaded into a Spark YARN scheduler.
+ * A Service that can be started and stopped.
+ *
+ * 1. For implementations to be loadable by `SchedulerExtensionServices`,
+ * they must provide an empty constructor.
+ * 2. The `stop()` operation MUST be idempotent, and succeed even if `start()` was
+ * never invoked.
+ */
+trait SchedulerExtensionService {
+
+ /**
+ * Start the extension service. This should be a no-op if
+ * called more than once.
+ * @param binding binding to the spark application and YARN
+ */
+ def start(binding: SchedulerExtensionServiceBinding): Unit
+
+ /**
+ * Stop the service
+ * The `stop()` operation MUST be idempotent, and succeed even if `start()` was
+ * never invoked.
+ */
+ def stop(): Unit
+}
+
+/**
+ * Binding information for a [[SchedulerExtensionService]].
+ *
+ * The attempt ID will be set if the service is started within a YARN application master;
+ * there is then a different attempt ID for every time that AM is restarted.
+ * When the service binding is instantiated in client mode, there's no attempt ID, as it lacks
+ * this information.
+ * @param sparkContext current spark context
+ * @param applicationId YARN application ID
+ * @param attemptId YARN attemptID. This will always be unset in client mode, and always set in
+ * cluster mode.
+ */
+case class SchedulerExtensionServiceBinding(
+ sparkContext: SparkContext,
+ applicationId: ApplicationId,
+ attemptId: Option[ApplicationAttemptId] = None)
+
+/**
+ * Container for [[SchedulerExtensionService]] instances.
+ *
+ * Loads Extension Services from the configuration property
+ * `"spark.yarn.services"`, instantiates and starts them.
+ * When stopped, it stops all child entries.
+ *
+ * The order in which child extension services are started and stopped
+ * is undefined.
+ */
+private[spark] class SchedulerExtensionServices extends SchedulerExtensionService
+ with Logging {
+ private var serviceOption: Option[String] = None
+ private var services: List[SchedulerExtensionService] = Nil
+ private val started = new AtomicBoolean(false)
+ private var binding: SchedulerExtensionServiceBinding = _
+
+ /**
+ * Binding operation will load the named services and call bind on them too; the
+ * entire set of services are then ready for `init()` and `start()` calls.
+ *
+ * @param binding binding to the spark application and YARN
+ */
+ def start(binding: SchedulerExtensionServiceBinding): Unit = {
+ if (started.getAndSet(true)) {
+ logWarning("Ignoring re-entrant start operation")
+ return
+ }
+ require(binding.sparkContext != null, "Null context parameter")
+ require(binding.applicationId != null, "Null appId parameter")
+ this.binding = binding
+ val sparkContext = binding.sparkContext
+ val appId = binding.applicationId
+ val attemptId = binding.attemptId
+ logInfo(s"Starting Yarn extension services with app $appId and attemptId $attemptId")
+
+ services = sparkContext.conf.get(SCHEDULER_SERVICES).map { sClass =>
+ val instance = Utils.classForName(sClass)
+ .newInstance()
+ .asInstanceOf[SchedulerExtensionService]
+ // bind this service
+ instance.start(binding)
+ logInfo(s"Service $sClass started")
+ instance
+ }.toList
+ }
+
+ /**
+ * Get the list of services.
+ *
+ * @return a list of services; Nil until the service is started
+ */
+ def getServices: List[SchedulerExtensionService] = services
+
+ /**
+ * Stop the services; idempotent.
+ *
+ */
+ override def stop(): Unit = {
+ if (started.getAndSet(false)) {
+ logInfo(s"Stopping $this")
+ services.foreach { s =>
+ Utils.tryLogNonFatalError(s.stop())
+ }
+ }
+ }
+
+ override def toString(): String = s"""SchedulerExtensionServices
+ |(serviceOption=$serviceOption,
+ | services=$services,
+ | started=$started)""".stripMargin
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
new file mode 100644
index 0000000000..60da356ad1
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.yarn.api.records.YarnApplicationState
+
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil}
+import org.apache.spark.internal.Logging
+import org.apache.spark.launcher.SparkAppHandle
+import org.apache.spark.scheduler.TaskSchedulerImpl
+
+private[spark] class YarnClientSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends YarnSchedulerBackend(scheduler, sc)
+ with Logging {
+
+ private var client: Client = null
+ private var monitorThread: MonitorThread = null
+
+ /**
+ * Create a Yarn client to submit an application to the ResourceManager.
+ * This waits until the application is running.
+ */
+ override def start() {
+ val driverHost = conf.get("spark.driver.host")
+ val driverPort = conf.get("spark.driver.port")
+ val hostport = driverHost + ":" + driverPort
+ sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.webUrl) }
+
+ val argsArrayBuf = new ArrayBuffer[String]()
+ argsArrayBuf += ("--arg", hostport)
+
+ logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" "))
+ val args = new ClientArguments(argsArrayBuf.toArray)
+ totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf)
+ client = new Client(args, conf)
+ bindToYarn(client.submitApplication(), None)
+
+ // SPARK-8687: Ensure all necessary properties have already been set before
+ // we initialize our driver scheduler backend, which serves these properties
+ // to the executors
+ super.start()
+ waitForApplication()
+
+ // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver
+ // reads the credentials from HDFS, just like the executors and updates its own credentials
+ // cache.
+ if (conf.contains("spark.yarn.credentials.file")) {
+ YarnSparkHadoopUtil.get.startCredentialUpdater(conf)
+ }
+ monitorThread = asyncMonitorApplication()
+ monitorThread.start()
+ }
+
+ /**
+ * Report the state of the application until it is running.
+ * If the application has finished, failed or been killed in the process, throw an exception.
+ * This assumes both `client` and `appId` have already been set.
+ */
+ private def waitForApplication(): Unit = {
+ assert(client != null && appId.isDefined, "Application has not been submitted yet!")
+ val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true) // blocking
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ throw new SparkException("Yarn application has already ended! " +
+ "It might have been killed or unable to launch application master.")
+ }
+ if (state == YarnApplicationState.RUNNING) {
+ logInfo(s"Application ${appId.get} has started running.")
+ }
+ }
+
+ /**
+ * We create this class for SPARK-9519. Basically when we interrupt the monitor thread it's
+ * because the SparkContext is being shut down(sc.stop() called by user code), but if
+ * monitorApplication return, it means the Yarn application finished before sc.stop() was called,
+ * which means we should call sc.stop() here, and we don't allow the monitor to be interrupted
+ * before SparkContext stops successfully.
+ */
+ private class MonitorThread extends Thread {
+ private var allowInterrupt = true
+
+ override def run() {
+ try {
+ val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false)
+ logError(s"Yarn application has already exited with state $state!")
+ allowInterrupt = false
+ sc.stop()
+ } catch {
+ case e: InterruptedException => logInfo("Interrupting monitor thread")
+ }
+ }
+
+ def stopMonitor(): Unit = {
+ if (allowInterrupt) {
+ this.interrupt()
+ }
+ }
+ }
+
+ /**
+ * Monitor the application state in a separate thread.
+ * If the application has exited for any reason, stop the SparkContext.
+ * This assumes both `client` and `appId` have already been set.
+ */
+ private def asyncMonitorApplication(): MonitorThread = {
+ assert(client != null && appId.isDefined, "Application has not been submitted yet!")
+ val t = new MonitorThread
+ t.setName("Yarn application state monitor")
+ t.setDaemon(true)
+ t
+ }
+
+ /**
+ * Stop the scheduler. This assumes `start()` has already been called.
+ */
+ override def stop() {
+ assert(client != null, "Attempted to stop this scheduler before starting it!")
+ if (monitorThread != null) {
+ monitorThread.stopMonitor()
+ }
+
+ // Report a final state to the launcher if one is connected. This is needed since in client
+ // mode this backend doesn't let the app monitor loop run to completion, so it does not report
+ // the final state itself.
+ //
+ // Note: there's not enough information at this point to provide a better final state,
+ // so assume the application was successful.
+ client.reportLauncherState(SparkAppHandle.State.FINISHED)
+
+ super.stop()
+ YarnSparkHadoopUtil.get.stopCredentialUpdater()
+ client.stop()
+ logInfo("Stopped")
+ }
+
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala
new file mode 100644
index 0000000000..64cd1bd088
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
+
+/**
+ * Cluster Manager for creation of Yarn scheduler and backend
+ */
+private[spark] class YarnClusterManager extends ExternalClusterManager {
+
+ override def canCreate(masterURL: String): Boolean = {
+ masterURL == "yarn"
+ }
+
+ override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = {
+ sc.deployMode match {
+ case "cluster" => new YarnClusterScheduler(sc)
+ case "client" => new YarnScheduler(sc)
+ case _ => throw new SparkException(s"Unknown deploy mode '${sc.deployMode}' for Yarn")
+ }
+ }
+
+ override def createSchedulerBackend(sc: SparkContext,
+ masterURL: String,
+ scheduler: TaskScheduler): SchedulerBackend = {
+ sc.deployMode match {
+ case "cluster" =>
+ new YarnClusterSchedulerBackend(scheduler.asInstanceOf[TaskSchedulerImpl], sc)
+ case "client" =>
+ new YarnClientSchedulerBackend(scheduler.asInstanceOf[TaskSchedulerImpl], sc)
+ case _ =>
+ throw new SparkException(s"Unknown deploy mode '${sc.deployMode}' for Yarn")
+ }
+ }
+
+ override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
+ scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
new file mode 100644
index 0000000000..96c9151fc3
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.spark.deploy.yarn.ApplicationMaster
+
+/**
+ * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of
+ * ApplicationMaster, etc is done
+ */
+private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnScheduler(sc) {
+
+ logInfo("Created YarnClusterScheduler")
+
+ override def postStartHook() {
+ ApplicationMaster.sparkContextInitialized(sc)
+ super.postStartHook()
+ logInfo("YarnClusterScheduler.postStartHook done")
+ }
+
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
new file mode 100644
index 0000000000..4f3d5ebf40
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+
+import org.apache.spark.SparkContext
+import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil}
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.util.Utils
+
+private[spark] class YarnClusterSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends YarnSchedulerBackend(scheduler, sc) {
+
+ override def start() {
+ val attemptId = ApplicationMaster.getAttemptId
+ bindToYarn(attemptId.getApplicationId(), Some(attemptId))
+ super.start()
+ totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf)
+ }
+
+ override def getDriverLogUrls: Option[Map[String, String]] = {
+ var driverLogs: Option[Map[String, String]] = None
+ try {
+ val yarnConf = new YarnConfiguration(sc.hadoopConfiguration)
+ val containerId = YarnSparkHadoopUtil.get.getContainerId
+
+ val httpAddress = System.getenv(Environment.NM_HOST.name()) +
+ ":" + System.getenv(Environment.NM_HTTP_PORT.name())
+ // lookup appropriate http scheme for container log urls
+ val yarnHttpPolicy = yarnConf.get(
+ YarnConfiguration.YARN_HTTP_POLICY_KEY,
+ YarnConfiguration.YARN_HTTP_POLICY_DEFAULT
+ )
+ val user = Utils.getCurrentUserName()
+ val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://"
+ val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user"
+ logDebug(s"Base URL for logs: $baseUrl")
+ driverLogs = Some(Map(
+ "stdout" -> s"$baseUrl/stdout?start=-4096",
+ "stderr" -> s"$baseUrl/stderr?start=-4096"))
+ } catch {
+ case e: Exception =>
+ logInfo("Error while building AM log links, so AM" +
+ " logs link will not appear in application UI", e)
+ }
+ driverLogs
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
new file mode 100644
index 0000000000..029382133d
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.util.RackResolver
+import org.apache.log4j.{Level, Logger}
+
+import org.apache.spark._
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.util.Utils
+
+private[spark] class YarnScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {
+
+ // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
+ if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
+ Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
+ }
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation)
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
new file mode 100644
index 0000000000..2f9ea1911f
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.util.{Failure, Success}
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc._
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.ui.JettyUtils
+import org.apache.spark.util.{RpcUtils, ThreadUtils}
+
+/**
+ * Abstract Yarn scheduler backend that contains common logic
+ * between the client and cluster Yarn scheduler backends.
+ */
+private[spark] abstract class YarnSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
+
+ override val minRegisteredRatio =
+ if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
+ 0.8
+ } else {
+ super.minRegisteredRatio
+ }
+
+ protected var totalExpectedExecutors = 0
+
+ private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv)
+
+ private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint(
+ YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint)
+
+ private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf)
+
+ /** Application ID. */
+ protected var appId: Option[ApplicationId] = None
+
+ /** Attempt ID. This is unset for client-mode schedulers */
+ private var attemptId: Option[ApplicationAttemptId] = None
+
+ /** Scheduler extension services. */
+ private val services: SchedulerExtensionServices = new SchedulerExtensionServices()
+
+ // Flag to specify whether this schedulerBackend should be reset.
+ private var shouldResetOnAmRegister = false
+
+ /**
+ * Bind to YARN. This *must* be done before calling [[start()]].
+ *
+ * @param appId YARN application ID
+ * @param attemptId Optional YARN attempt ID
+ */
+ protected def bindToYarn(appId: ApplicationId, attemptId: Option[ApplicationAttemptId]): Unit = {
+ this.appId = Some(appId)
+ this.attemptId = attemptId
+ }
+
+ override def start() {
+ require(appId.isDefined, "application ID unset")
+ val binding = SchedulerExtensionServiceBinding(sc, appId.get, attemptId)
+ services.start(binding)
+ super.start()
+ }
+
+ override def stop(): Unit = {
+ try {
+ // SPARK-12009: To prevent Yarn allocator from requesting backup for the executors which
+ // was Stopped by SchedulerBackend.
+ requestTotalExecutors(0, 0, Map.empty)
+ super.stop()
+ } finally {
+ services.stop()
+ }
+ }
+
+ /**
+ * Get the attempt ID for this run, if the cluster manager supports multiple
+ * attempts. Applications run in client mode will not have attempt IDs.
+ * This attempt ID only includes attempt counter, like "1", "2".
+ *
+ * @return The application attempt id, if available.
+ */
+ override def applicationAttemptId(): Option[String] = {
+ attemptId.map(_.getAttemptId.toString)
+ }
+
+ /**
+ * Get an application ID associated with the job.
+ * This returns the string value of [[appId]] if set, otherwise
+ * the locally-generated ID from the superclass.
+ * @return The application ID
+ */
+ override def applicationId(): String = {
+ appId.map(_.toString).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
+ }
+
+ /**
+ * Request executors from the ApplicationMaster by specifying the total number desired.
+ * This includes executors already pending or running.
+ */
+ override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = {
+ yarnSchedulerEndpointRef.ask[Boolean](
+ RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount))
+ }
+
+ /**
+ * Request that the ApplicationMaster kill the specified executors.
+ */
+ override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = {
+ yarnSchedulerEndpointRef.ask[Boolean](KillExecutors(executorIds))
+ }
+
+ override def sufficientResourcesRegistered(): Boolean = {
+ totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
+ }
+
+ /**
+ * Add filters to the SparkUI.
+ */
+ private def addWebUIFilter(
+ filterName: String,
+ filterParams: Map[String, String],
+ proxyBase: String): Unit = {
+ if (proxyBase != null && proxyBase.nonEmpty) {
+ System.setProperty("spark.ui.proxyBase", proxyBase)
+ }
+
+ val hasFilter =
+ filterName != null && filterName.nonEmpty &&
+ filterParams != null && filterParams.nonEmpty
+ if (hasFilter) {
+ logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
+ conf.set("spark.ui.filters", filterName)
+ filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
+ scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
+ }
+ }
+
+ override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
+ new YarnDriverEndpoint(rpcEnv, properties)
+ }
+
+ /**
+ * Reset the state of SchedulerBackend to the initial state. This is happened when AM is failed
+ * and re-registered itself to driver after a failure. The stale state in driver should be
+ * cleaned.
+ */
+ override protected def reset(): Unit = {
+ super.reset()
+ sc.executorAllocationManager.foreach(_.reset())
+ }
+
+ /**
+ * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected.
+ * This endpoint communicates with the executors and queries the AM for an executor's exit
+ * status when the executor is disconnected.
+ */
+ private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
+ extends DriverEndpoint(rpcEnv, sparkProperties) {
+
+ /**
+ * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint
+ * handles it by assuming the Executor was lost for a bad reason and removes the executor
+ * immediately.
+ *
+ * In YARN's case however it is crucial to talk to the application master and ask why the
+ * executor had exited. If the executor exited for some reason unrelated to the running tasks
+ * (e.g., preemption), according to the application master, then we pass that information down
+ * to the TaskSetManager to inform the TaskSetManager that tasks on that lost executor should
+ * not count towards a job failure.
+ */
+ override def onDisconnected(rpcAddress: RpcAddress): Unit = {
+ addressToExecutorId.get(rpcAddress).foreach { executorId =>
+ if (disableExecutor(executorId)) {
+ yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress)
+ }
+ }
+ }
+ }
+
+ /**
+ * An [[RpcEndpoint]] that communicates with the ApplicationMaster.
+ */
+ private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv)
+ extends ThreadSafeRpcEndpoint with Logging {
+ private var amEndpoint: Option[RpcEndpointRef] = None
+
+ private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver(
+ executorId: String,
+ executorRpcAddress: RpcAddress): Unit = {
+ val removeExecutorMessage = amEndpoint match {
+ case Some(am) =>
+ val lossReasonRequest = GetExecutorLossReason(executorId)
+ am.ask[ExecutorLossReason](lossReasonRequest, askTimeout)
+ .map { reason => RemoveExecutor(executorId, reason) }(ThreadUtils.sameThread)
+ .recover {
+ case NonFatal(e) =>
+ logWarning(s"Attempted to get executor loss reason" +
+ s" for executor id ${executorId} at RPC address ${executorRpcAddress}," +
+ s" but got no response. Marking as slave lost.", e)
+ RemoveExecutor(executorId, SlaveLost())
+ }(ThreadUtils.sameThread)
+ case None =>
+ logWarning("Attempted to check for an executor loss reason" +
+ " before the AM has registered!")
+ Future.successful(RemoveExecutor(executorId, SlaveLost("AM is not yet registered.")))
+ }
+
+ removeExecutorMessage
+ .flatMap { message =>
+ driverEndpoint.ask[Boolean](message)
+ }(ThreadUtils.sameThread)
+ .onFailure {
+ case NonFatal(e) => logError(
+ s"Error requesting driver to remove executor $executorId after disconnection.", e)
+ }(ThreadUtils.sameThread)
+ }
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case RegisterClusterManager(am) =>
+ logInfo(s"ApplicationMaster registered as $am")
+ amEndpoint = Option(am)
+ if (!shouldResetOnAmRegister) {
+ shouldResetOnAmRegister = true
+ } else {
+ // AM is already registered before, this potentially means that AM failed and
+ // a new one registered after the failure. This will only happen in yarn-client mode.
+ reset()
+ }
+
+ case AddWebUIFilter(filterName, filterParams, proxyBase) =>
+ addWebUIFilter(filterName, filterParams, proxyBase)
+
+ case r @ RemoveExecutor(executorId, reason) =>
+ logWarning(reason.toString)
+ driverEndpoint.ask[Boolean](r).onFailure {
+ case e =>
+ logError("Error requesting driver to remove executor" +
+ s" $executorId for reason $reason", e)
+ }(ThreadUtils.sameThread)
+ }
+
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case r: RequestExecutors =>
+ amEndpoint match {
+ case Some(am) =>
+ am.ask[Boolean](r).andThen {
+ case Success(b) => context.reply(b)
+ case Failure(NonFatal(e)) =>
+ logError(s"Sending $r to AM was unsuccessful", e)
+ context.sendFailure(e)
+ }(ThreadUtils.sameThread)
+ case None =>
+ logWarning("Attempted to request executors before the AM has registered!")
+ context.reply(false)
+ }
+
+ case k: KillExecutors =>
+ amEndpoint match {
+ case Some(am) =>
+ am.ask[Boolean](k).andThen {
+ case Success(b) => context.reply(b)
+ case Failure(NonFatal(e)) =>
+ logError(s"Sending $k to AM was unsuccessful", e)
+ context.sendFailure(e)
+ }(ThreadUtils.sameThread)
+ case None =>
+ logWarning("Attempted to kill executors before the AM has registered!")
+ context.reply(false)
+ }
+
+ case RetrieveLastAllocatedExecutorId =>
+ context.reply(currentExecutorIdCounter)
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (amEndpoint.exists(_.address == remoteAddress)) {
+ logWarning(s"ApplicationMaster has disassociated: $remoteAddress")
+ amEndpoint = None
+ }
+ }
+ }
+}
+
+private[spark] object YarnSchedulerBackend {
+ val ENDPOINT_NAME = "YarnScheduler"
+}