aboutsummaryrefslogtreecommitdiff
path: root/yarn/src
diff options
context:
space:
mode:
authorSandy Ryza <sandy@cloudera.com>2014-12-09 11:02:43 -0800
committerAndrew Or <andrew@databricks.com>2014-12-09 11:02:43 -0800
commit912563aa3553afc0871d5b5858f533aa39cb99e5 (patch)
tree092241ac4c78deef8053f5095bc9680b3c8532cd /yarn/src
parent383c5555c9f26c080bc9e3a463aab21dd5b3797f (diff)
downloadspark-912563aa3553afc0871d5b5858f533aa39cb99e5.tar.gz
spark-912563aa3553afc0871d5b5858f533aa39cb99e5.tar.bz2
spark-912563aa3553afc0871d5b5858f533aa39cb99e5.zip
SPARK-4338. [YARN] Ditch yarn-alpha.
Sorry if this is a little premature with 1.2 still not out the door, but it will make other work like SPARK-4136 and SPARK-2089 a lot easier. Author: Sandy Ryza <sandy@cloudera.com> Closes #3215 from sryza/sandy-spark-4338 and squashes the following commits: 1c5ac08 [Sandy Ryza] Update building Spark docs and remove unnecessary newline 9c1421c [Sandy Ryza] SPARK-4338. Ditch yarn-alpha.
Diffstat (limited to 'yarn/src')
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala539
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala96
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala141
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala202
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala823
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala207
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala113
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala203
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala213
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala538
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala68
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala110
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala226
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala35
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala157
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala56
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala50
-rw-r--r--yarn/src/test/resources/log4j.properties28
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala256
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala220
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala34
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala189
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala151
23 files changed, 4655 insertions, 0 deletions
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000..987b3373fb
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,539 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.util.control.NonFatal
+
+import java.io.IOException
+import java.lang.reflect.InvocationTargetException
+import java.net.Socket
+import java.util.concurrent.atomic.AtomicReference
+
+import akka.actor._
+import akka.remote._
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.util.ShutdownHookManager
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv}
+import org.apache.spark.SparkException
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.history.HistoryServer
+import org.apache.spark.scheduler.cluster.YarnSchedulerBackend
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils}
+
+/**
+ * Common application master functionality for Spark on Yarn.
+ */
+private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
+ client: YarnRMClient) extends Logging {
+ // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
+ // optimal as more containers are available. Might need to handle this better.
+
+ private val sparkConf = new SparkConf()
+ private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf)
+ .asInstanceOf[YarnConfiguration]
+ private val isDriver = args.userClass != null
+
+ // Default to numExecutors * 2, with minimum of 3
+ private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures",
+ sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3)))
+
+ @volatile private var exitCode = 0
+ @volatile private var unregistered = false
+ @volatile private var finished = false
+ @volatile private var finalStatus = FinalApplicationStatus.SUCCEEDED
+ @volatile private var finalMsg: String = ""
+ @volatile private var userClassThread: Thread = _
+
+ private var reporterThread: Thread = _
+ private var allocator: YarnAllocator = _
+
+ // Fields used in client mode.
+ private var actorSystem: ActorSystem = null
+ private var actor: ActorRef = _
+
+ // Fields used in cluster mode.
+ private val sparkContextRef = new AtomicReference[SparkContext](null)
+
+ final def run(): Int = {
+ try {
+ val appAttemptId = client.getAttemptId()
+
+ if (isDriver) {
+ // Set the web ui port to be ephemeral for yarn so we don't conflict with
+ // other spark processes running on the same box
+ System.setProperty("spark.ui.port", "0")
+
+ // Set the master property to match the requested mode.
+ System.setProperty("spark.master", "yarn-cluster")
+
+ // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up.
+ System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString())
+ }
+
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+
+ val fs = FileSystem.get(yarnConf)
+ val cleanupHook = new Runnable {
+ override def run() {
+ // If the SparkContext is still registered, shut it down as a best case effort in case
+ // users do not call sc.stop or do System.exit().
+ val sc = sparkContextRef.get()
+ if (sc != null) {
+ logInfo("Invoking sc stop from shutdown hook")
+ sc.stop()
+ }
+ val maxAppAttempts = client.getMaxRegAttempts(yarnConf)
+ val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts
+
+ if (!finished) {
+ // This happens when the user application calls System.exit(). We have the choice
+ // of either failing or succeeding at this point. We report success to avoid
+ // retrying applications that have succeeded (System.exit(0)), which means that
+ // applications that explicitly exit with a non-zero status will also show up as
+ // succeeded in the RM UI.
+ finish(finalStatus,
+ ApplicationMaster.EXIT_SUCCESS,
+ "Shutdown hook called before final status was reported.")
+ }
+
+ if (!unregistered) {
+ // we only want to unregister if we don't want the RM to retry
+ if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) {
+ unregister(finalStatus, finalMsg)
+ cleanupStagingDir(fs)
+ }
+ }
+ }
+ }
+
+ // Use higher priority than FileSystem.
+ assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY)
+ ShutdownHookManager
+ .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY)
+
+ // Call this to force generation of secret so it gets populated into the
+ // Hadoop UGI. This has to happen before the startUserClass which does a
+ // doAs in order for the credentials to be passed on to the executor containers.
+ val securityMgr = new SecurityManager(sparkConf)
+
+ if (isDriver) {
+ runDriver(securityMgr)
+ } else {
+ runExecutorLauncher(securityMgr)
+ }
+ } catch {
+ case e: Exception =>
+ // catch everything else if not specifically handled
+ logError("Uncaught exception: ", e)
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION,
+ "Uncaught exception: " + e.getMessage())
+ }
+ exitCode
+ }
+
+ /**
+ * unregister is used to completely unregister the application from the ResourceManager.
+ * This means the ResourceManager will not retry the application attempt on your behalf if
+ * a failure occurred.
+ */
+ final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized {
+ if (!unregistered) {
+ logInfo(s"Unregistering ApplicationMaster with $status" +
+ Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse(""))
+ unregistered = true
+ client.unregister(status, Option(diagnostics).getOrElse(""))
+ }
+ }
+
+ final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized {
+ if (!finished) {
+ val inShutdown = Utils.inShutdown()
+ logInfo(s"Final app status: ${status}, exitCode: ${code}" +
+ Option(msg).map(msg => s", (reason: $msg)").getOrElse(""))
+ exitCode = code
+ finalStatus = status
+ finalMsg = msg
+ finished = true
+ if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) {
+ logDebug("shutting down reporter thread")
+ reporterThread.interrupt()
+ }
+ if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) {
+ logDebug("shutting down user thread")
+ userClassThread.interrupt()
+ }
+ }
+ }
+
+ private def sparkContextInitialized(sc: SparkContext) = {
+ sparkContextRef.synchronized {
+ sparkContextRef.compareAndSet(null, sc)
+ sparkContextRef.notifyAll()
+ }
+ }
+
+ private def sparkContextStopped(sc: SparkContext) = {
+ sparkContextRef.compareAndSet(sc, null)
+ }
+
+ private def registerAM(uiAddress: String, securityMgr: SecurityManager) = {
+ val sc = sparkContextRef.get()
+
+ val appId = client.getAttemptId().getApplicationId().toString()
+ val historyAddress =
+ sparkConf.getOption("spark.yarn.historyServer.address")
+ .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" }
+ .getOrElse("")
+
+ allocator = client.register(yarnConf,
+ if (sc != null) sc.getConf else sparkConf,
+ if (sc != null) sc.preferredNodeLocationData else Map(),
+ uiAddress,
+ historyAddress,
+ securityMgr)
+
+ allocator.allocateResources()
+ reporterThread = launchReporterThread()
+ }
+
+ private def runDriver(securityMgr: SecurityManager): Unit = {
+ addAmIpFilter()
+ userClassThread = startUserClass()
+
+ // This a bit hacky, but we need to wait until the spark.driver.port property has
+ // been set by the Thread executing the user class.
+ val sc = waitForSparkContextInitialized()
+
+ // If there is no SparkContext at this point, just fail the app.
+ if (sc == null) {
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_SC_NOT_INITED,
+ "Timed out waiting for SparkContext.")
+ } else {
+ registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
+ userClassThread.join()
+ }
+ }
+
+ private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
+ actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
+ conf = sparkConf, securityManager = securityMgr)._1
+ actor = waitForSparkDriver()
+ addAmIpFilter()
+ registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
+
+ // In client mode the actor will stop the reporter thread.
+ reporterThread.join()
+ }
+
+ private def launchReporterThread(): Thread = {
+ // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
+ val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000)
+
+ // must be <= expiryInterval / 2.
+ val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval))
+
+ // The number of failures in a row until Reporter thread give up
+ val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5)
+
+ val t = new Thread {
+ override def run() {
+ var failureCount = 0
+ while (!finished) {
+ try {
+ if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES,
+ "Max number of executor failures reached")
+ } else {
+ logDebug("Sending progress")
+ allocator.allocateResources()
+ }
+ failureCount = 0
+ } catch {
+ case i: InterruptedException =>
+ case e: Throwable => {
+ failureCount += 1
+ if (!NonFatal(e) || failureCount >= reporterMaxFailures) {
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " +
+ s"${failureCount} time(s) from Reporter thread.")
+
+ } else {
+ logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e)
+ }
+ }
+ }
+ try {
+ Thread.sleep(interval)
+ } catch {
+ case e: InterruptedException =>
+ }
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.setName("Reporter")
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + interval)
+ t
+ }
+
+ /**
+ * Clean up the staging directory.
+ */
+ private def cleanupStagingDir(fs: FileSystem) {
+ var stagingDirPath: Path = null
+ try {
+ val preserveFiles = sparkConf.get("spark.yarn.preserve.staging.files", "false").toBoolean
+ if (!preserveFiles) {
+ stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR"))
+ if (stagingDirPath == null) {
+ logError("Staging directory is null")
+ return
+ }
+ logInfo("Deleting staging directory " + stagingDirPath)
+ fs.delete(stagingDirPath, true)
+ }
+ } catch {
+ case ioe: IOException =>
+ logError("Failed to cleanup staging dir " + stagingDirPath, ioe)
+ }
+ }
+
+ private def waitForSparkContextInitialized(): SparkContext = {
+ logInfo("Waiting for spark context initialization")
+ try {
+ sparkContextRef.synchronized {
+ var count = 0
+ val waitTime = 10000L
+ val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 10)
+ while (sparkContextRef.get() == null && count < numTries && !finished) {
+ logInfo("Waiting for spark context initialization ... " + count)
+ count = count + 1
+ sparkContextRef.wait(waitTime)
+ }
+
+ val sparkContext = sparkContextRef.get()
+ if (sparkContext == null) {
+ logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
+ + " log output for errors. Failing the application.").format(numTries * waitTime))
+ }
+ sparkContext
+ }
+ }
+ }
+
+ private def waitForSparkDriver(): ActorRef = {
+ logInfo("Waiting for Spark driver to be reachable.")
+ var driverUp = false
+ var count = 0
+ val hostport = args.userArgs(0)
+ val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+
+ // spark driver should already be up since it launched us, but we don't want to
+ // wait forever, so wait 100 seconds max to match the cluster mode setting.
+ // Leave this config unpublished for now. SPARK-3779 to investigating changing
+ // this config to be time based.
+ val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 1000)
+
+ while (!driverUp && !finished && count < numTries) {
+ try {
+ count = count + 1
+ val socket = new Socket(driverHost, driverPort)
+ socket.close()
+ logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at %s:%s, retrying ...".
+ format(driverHost, driverPort))
+ Thread.sleep(100)
+ }
+ }
+
+ if (!driverUp) {
+ throw new SparkException("Failed to connect to driver!")
+ }
+
+ sparkConf.set("spark.driver.host", driverHost)
+ sparkConf.set("spark.driver.port", driverPort.toString)
+
+ val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ SparkEnv.driverActorSystemName,
+ driverHost,
+ driverPort.toString,
+ YarnSchedulerBackend.ACTOR_NAME)
+ actorSystem.actorOf(Props(new AMActor(driverUrl)), name = "YarnAM")
+ }
+
+ /** Add the Yarn IP filter that is required for properly securing the UI. */
+ private def addAmIpFilter() = {
+ val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
+ val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
+ val params = client.getAmIpFilterParams(yarnConf, proxyBase)
+ if (isDriver) {
+ System.setProperty("spark.ui.filters", amFilter)
+ params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) }
+ } else {
+ actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase)
+ }
+ }
+
+ /**
+ * Start the user class, which contains the spark driver, in a separate Thread.
+ * If the main routine exits cleanly or exits with System.exit(N) for any N
+ * we assume it was successful, for all other cases we assume failure.
+ *
+ * Returns the user thread that was started.
+ */
+ private def startUserClass(): Thread = {
+ logInfo("Starting the user JAR in a separate Thread")
+ System.setProperty("spark.executor.instances", args.numExecutors.toString)
+ val mainMethod = Class.forName(args.userClass, false,
+ Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
+
+ val userThread = new Thread {
+ override def run() {
+ try {
+ val mainArgs = new Array[String](args.userArgs.size)
+ args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
+ mainMethod.invoke(null, mainArgs)
+ finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+ logDebug("Done running users class")
+ } catch {
+ case e: InvocationTargetException =>
+ e.getCause match {
+ case _: InterruptedException =>
+ // Reporter thread can interrupt to stop user class
+ case e: Exception =>
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
+ "User class threw exception: " + e.getMessage)
+ // re-throw to get it logged
+ throw e
+ }
+ }
+ }
+ }
+ userThread.setName("Driver")
+ userThread.start()
+ userThread
+ }
+
+ /**
+ * Actor that communicates with the driver in client deploy mode.
+ */
+ private class AMActor(driverUrl: String) extends Actor {
+ var driver: ActorSelection = _
+
+ override def preStart() = {
+ logInfo("Listen to driver: " + driverUrl)
+ driver = context.actorSelection(driverUrl)
+ // Send a hello message to establish the connection, after which
+ // we can monitor Lifecycle Events.
+ driver ! "Hello"
+ driver ! RegisterClusterManager
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ }
+
+ override def receive = {
+ case x: DisassociatedEvent =>
+ logInfo(s"Driver terminated or disconnected! Shutting down. $x")
+ finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+
+ case x: AddWebUIFilter =>
+ logInfo(s"Add WebUI Filter. $x")
+ driver ! x
+
+ case RequestExecutors(requestedTotal) =>
+ logInfo(s"Driver requested a total number of $requestedTotal executor(s).")
+ Option(allocator) match {
+ case Some(a) => a.requestTotalExecutors(requestedTotal)
+ case None => logWarning("Container allocator is not ready to request executors yet.")
+ }
+ sender ! true
+
+ case KillExecutors(executorIds) =>
+ logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.")
+ Option(allocator) match {
+ case Some(a) => executorIds.foreach(a.killExecutor)
+ case None => logWarning("Container allocator is not ready to kill executors yet.")
+ }
+ sender ! true
+ }
+ }
+
+}
+
+object ApplicationMaster extends Logging {
+
+ val SHUTDOWN_HOOK_PRIORITY: Int = 30
+
+ // exit codes for different causes, no reason behind the values
+ private val EXIT_SUCCESS = 0
+ private val EXIT_UNCAUGHT_EXCEPTION = 10
+ private val EXIT_MAX_EXECUTOR_FAILURES = 11
+ private val EXIT_REPORTER_FAILURE = 12
+ private val EXIT_SC_NOT_INITED = 13
+ private val EXIT_SECURITY = 14
+ private val EXIT_EXCEPTION_USER_CLASS = 15
+
+ private var master: ApplicationMaster = _
+
+ def main(args: Array[String]) = {
+ SignalLogger.register(log)
+ val amArgs = new ApplicationMasterArguments(args)
+ SparkHadoopUtil.get.runAsSparkUser { () =>
+ master = new ApplicationMaster(amArgs, new YarnRMClientImpl(amArgs))
+ System.exit(master.run())
+ }
+ }
+
+ private[spark] def sparkContextInitialized(sc: SparkContext) = {
+ master.sparkContextInitialized(sc)
+ }
+
+ private[spark] def sparkContextStopped(sc: SparkContext) = {
+ master.sparkContextStopped(sc)
+ }
+
+}
+
+/**
+ * This object does not provide any special functionality. It exists so that it's easy to tell
+ * apart the client-mode AM from the cluster-mode AM when using tools such as ps or jps.
+ */
+object ExecutorLauncher {
+
+ def main(args: Array[String]) = {
+ ApplicationMaster.main(args)
+ }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000..d76a63276d
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.spark.util.{MemoryParam, IntParam}
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import collection.mutable.ArrayBuffer
+
+class ApplicationMasterArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var executorMemory = 1024
+ var executorCores = 1
+ var numExecutors = DEFAULT_NUMBER_EXECUTORS
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+
+ var args = inputArgs
+
+ while (!args.isEmpty) {
+ // --num-workers, --worker-memory, and --worker-cores are deprecated since 1.0,
+ // the properties with executor in their names are preferred.
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args" | "--arg") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail =>
+ numExecutors = value
+ args = tail
+
+ case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail =>
+ executorMemory = value
+ args = tail
+
+ case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail =>
+ executorCores = value
+ args = tail
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ }
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println("""
+ |Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options]
+ |Options:
+ | --jar JAR_PATH Path to your application's JAR file
+ | --class CLASS_NAME Name of your application's main class
+ | --args ARGS Arguments to be passed to your application's main class.
+ | Multiple invocations are possible, each will be passed in order.
+ | --num-executors NUM Number of executors to start (Default: 2)
+ | --executor-cores NUM Number of cores for the executors (Default: 1)
+ | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G)
+ """.stripMargin)
+ System.exit(exitCode)
+ }
+}
+
+object ApplicationMasterArguments {
+ val DEFAULT_NUMBER_EXECUTORS = 2
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
new file mode 100644
index 0000000000..addaddb711
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.nio.ByteBuffer
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication}
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.Records
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.deploy.SparkHadoopUtil
+
+/**
+ * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's stable API.
+ */
+private[spark] class Client(
+ val args: ClientArguments,
+ val hadoopConf: Configuration,
+ val sparkConf: SparkConf)
+ extends ClientBase with Logging {
+
+ def this(clientArgs: ClientArguments, spConf: SparkConf) =
+ this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf)
+
+ def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf())
+
+ val yarnClient = YarnClient.createYarnClient
+ val yarnConf = new YarnConfiguration(hadoopConf)
+
+ def stop(): Unit = yarnClient.stop()
+
+ /* ------------------------------------------------------------------------------------- *
+ | The following methods have much in common in the stable and alpha versions of Client, |
+ | but cannot be implemented in the parent trait due to subtle API differences across |
+ | hadoop versions. |
+ * ------------------------------------------------------------------------------------- */
+
+ /**
+ * Submit an application running our ApplicationMaster to the ResourceManager.
+ *
+ * The stable Yarn API provides a convenience method (YarnClient#createApplication) for
+ * creating applications and setting up the application submission context. This was not
+ * available in the alpha API.
+ */
+ override def submitApplication(): ApplicationId = {
+ yarnClient.init(yarnConf)
+ yarnClient.start()
+
+ logInfo("Requesting a new application from cluster with %d NodeManagers"
+ .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers))
+
+ // Get a new application from our RM
+ val newApp = yarnClient.createApplication()
+ val newAppResponse = newApp.getNewApplicationResponse()
+ val appId = newAppResponse.getApplicationId()
+
+ // Verify whether the cluster has enough resources for our AM
+ verifyClusterResources(newAppResponse)
+
+ // Set up the appropriate contexts to launch our AM
+ val containerContext = createContainerLaunchContext(newAppResponse)
+ val appContext = createApplicationSubmissionContext(newApp, containerContext)
+
+ // Finally, submit and monitor the application
+ logInfo(s"Submitting application ${appId.getId} to ResourceManager")
+ yarnClient.submitApplication(appContext)
+ appId
+ }
+
+ /**
+ * Set up the context for submitting our ApplicationMaster.
+ * This uses the YarnClientApplication not available in the Yarn alpha API.
+ */
+ def createApplicationSubmissionContext(
+ newApp: YarnClientApplication,
+ containerContext: ContainerLaunchContext): ApplicationSubmissionContext = {
+ val appContext = newApp.getApplicationSubmissionContext
+ appContext.setApplicationName(args.appName)
+ appContext.setQueue(args.amQueue)
+ appContext.setAMContainerSpec(containerContext)
+ appContext.setApplicationType("SPARK")
+ val capability = Records.newRecord(classOf[Resource])
+ capability.setMemory(args.amMemory + amMemoryOverhead)
+ appContext.setResource(capability)
+ appContext
+ }
+
+ /** Set up security tokens for launching our ApplicationMaster container. */
+ override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = {
+ val dob = new DataOutputBuffer
+ credentials.writeTokenStorageToStream(dob)
+ amContainer.setTokens(ByteBuffer.wrap(dob.getData))
+ }
+
+ /** Get the application report from the ResourceManager for an application we have submitted. */
+ override def getApplicationReport(appId: ApplicationId): ApplicationReport =
+ yarnClient.getApplicationReport(appId)
+
+ /**
+ * Return the security token used by this client to communicate with the ApplicationMaster.
+ * If no security is enabled, the token returned by the report is null.
+ */
+ override def getClientToken(report: ApplicationReport): String =
+ Option(report.getClientToAMToken).map(_.toString).getOrElse("")
+}
+
+object Client {
+ def main(argStrings: Array[String]) {
+ if (!sys.props.contains("SPARK_SUBMIT")) {
+ println("WARNING: This client is deprecated and will be removed in a " +
+ "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"")
+ }
+
+ // Set an env variable indicating we are running in YARN mode.
+ // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes
+ System.setProperty("SPARK_YARN_MODE", "true")
+ val sparkConf = new SparkConf
+
+ val args = new ClientArguments(argStrings, sparkConf)
+ new Client(args, sparkConf).run()
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
new file mode 100644
index 0000000000..c439969510
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.util.{Utils, IntParam, MemoryParam}
+
+// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
+private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) {
+ var addJars: String = null
+ var files: String = null
+ var archives: String = null
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var executorMemory = 1024 // MB
+ var executorCores = 1
+ var numExecutors = DEFAULT_NUMBER_EXECUTORS
+ var amQueue = sparkConf.get("spark.yarn.queue", "default")
+ var amMemory: Int = 512 // MB
+ var appName: String = "Spark"
+ var priority = 0
+
+ // Additional memory to allocate to containers
+ // For now, use driver's memory overhead as our AM container's memory overhead
+ val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead",
+ math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN))
+
+ val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead",
+ math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN))
+
+ private val isDynamicAllocationEnabled =
+ sparkConf.getBoolean("spark.dynamicAllocation.enabled", false)
+
+ parseArgs(args.toList)
+ loadEnvironmentArgs()
+ validateArgs()
+
+ /** Load any default arguments provided through environment variables and Spark properties. */
+ private def loadEnvironmentArgs(): Unit = {
+ // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://,
+ // while spark.yarn.dist.{archives/files} should be resolved to file:// (SPARK-2051).
+ files = Option(files)
+ .orElse(sys.env.get("SPARK_YARN_DIST_FILES"))
+ .orElse(sparkConf.getOption("spark.yarn.dist.files").map(p => Utils.resolveURIs(p)))
+ .orNull
+ archives = Option(archives)
+ .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES"))
+ .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p)))
+ .orNull
+ // If dynamic allocation is enabled, start at the max number of executors
+ if (isDynamicAllocationEnabled) {
+ val maxExecutorsConf = "spark.dynamicAllocation.maxExecutors"
+ if (!sparkConf.contains(maxExecutorsConf)) {
+ throw new IllegalArgumentException(
+ s"$maxExecutorsConf must be set if dynamic allocation is enabled!")
+ }
+ numExecutors = sparkConf.get(maxExecutorsConf).toInt
+ }
+ }
+
+ /**
+ * Fail fast if any arguments provided are invalid.
+ * This is intended to be called only after the provided arguments have been parsed.
+ */
+ private def validateArgs(): Unit = {
+ if (numExecutors <= 0) {
+ throw new IllegalArgumentException(
+ "You must specify at least 1 executor!\n" + getUsageMessage())
+ }
+ }
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+ var args = inputArgs
+
+ while (!args.isEmpty) {
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args" | "--arg") :: value :: tail =>
+ if (args(0) == "--args") {
+ println("--args is deprecated. Use --arg instead.")
+ }
+ userArgsBuffer += value
+ args = tail
+
+ case ("--master-class" | "--am-class") :: value :: tail =>
+ println(s"${args(0)} is deprecated and is not used anymore.")
+ args = tail
+
+ case ("--master-memory" | "--driver-memory") :: MemoryParam(value) :: tail =>
+ if (args(0) == "--master-memory") {
+ println("--master-memory is deprecated. Use --driver-memory instead.")
+ }
+ amMemory = value
+ args = tail
+
+ case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail =>
+ if (args(0) == "--num-workers") {
+ println("--num-workers is deprecated. Use --num-executors instead.")
+ }
+ // Dynamic allocation is not compatible with this option
+ if (isDynamicAllocationEnabled) {
+ throw new IllegalArgumentException("Explicitly setting the number " +
+ "of executors is not compatible with spark.dynamicAllocation.enabled!")
+ }
+ numExecutors = value
+ args = tail
+
+ case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail =>
+ if (args(0) == "--worker-memory") {
+ println("--worker-memory is deprecated. Use --executor-memory instead.")
+ }
+ executorMemory = value
+ args = tail
+
+ case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail =>
+ if (args(0) == "--worker-cores") {
+ println("--worker-cores is deprecated. Use --executor-cores instead.")
+ }
+ executorCores = value
+ args = tail
+
+ case ("--queue") :: value :: tail =>
+ amQueue = value
+ args = tail
+
+ case ("--name") :: value :: tail =>
+ appName = value
+ args = tail
+
+ case ("--addJars") :: value :: tail =>
+ addJars = value
+ args = tail
+
+ case ("--files") :: value :: tail =>
+ files = value
+ args = tail
+
+ case ("--archives") :: value :: tail =>
+ archives = value
+ args = tail
+
+ case Nil =>
+
+ case _ =>
+ throw new IllegalArgumentException(getUsageMessage(args))
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ }
+
+ private def getUsageMessage(unknownParam: List[String] = null): String = {
+ val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else ""
+ message + """
+ |Usage: org.apache.spark.deploy.yarn.Client [options]
+ |Options:
+ | --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster
+ | mode)
+ | --class CLASS_NAME Name of your application's main class (required)
+ | --arg ARG Argument to be passed to your application's main class.
+ | Multiple invocations are possible, each will be passed in order.
+ | --num-executors NUM Number of executors to start (Default: 2)
+ | --executor-cores NUM Number of cores for the executors (Default: 1).
+ | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb)
+ | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G)
+ | --name NAME The name of your application (Default: Spark)
+ | --queue QUEUE The hadoop queue to use for allocation requests (Default:
+ | 'default')
+ | --addJars jars Comma separated list of local jars that want SparkContext.addJar
+ | to work with.
+ | --files files Comma separated list of files to be distributed with the job.
+ | --archives archives Comma separated list of archives to be distributed with the job.
+ """
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
new file mode 100644
index 0000000000..f95d723791
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -0,0 +1,823 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException}
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{HashMap, ListBuffer, Map}
+import scala.util.{Try, Success, Failure}
+
+import com.google.common.base.Objects
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs._
+import org.apache.hadoop.fs.permission.FsPermission
+import org.apache.hadoop.mapred.Master
+import org.apache.hadoop.mapreduce.MRJobConfig
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import org.apache.hadoop.util.StringUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.Records
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException}
+import org.apache.spark.util.Utils
+
+/**
+ * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN.
+ * The Client submits an application to the YARN ResourceManager.
+ */
+private[spark] trait ClientBase extends Logging {
+ import ClientBase._
+
+ protected val args: ClientArguments
+ protected val hadoopConf: Configuration
+ protected val sparkConf: SparkConf
+ protected val yarnConf: YarnConfiguration
+ protected val credentials = UserGroupInformation.getCurrentUser.getCredentials
+ protected val amMemoryOverhead = args.amMemoryOverhead // MB
+ protected val executorMemoryOverhead = args.executorMemoryOverhead // MB
+ private val distCacheMgr = new ClientDistributedCacheManager()
+ private val isLaunchingDriver = args.userClass != null
+
+ /**
+ * Fail fast if we have requested more resources per container than is available in the cluster.
+ */
+ protected def verifyClusterResources(newAppResponse: GetNewApplicationResponse): Unit = {
+ val maxMem = newAppResponse.getMaximumResourceCapability().getMemory()
+ logInfo("Verifying our application has not requested more than the maximum " +
+ s"memory capability of the cluster ($maxMem MB per container)")
+ val executorMem = args.executorMemory + executorMemoryOverhead
+ if (executorMem > maxMem) {
+ throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" +
+ s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!")
+ }
+ val amMem = args.amMemory + amMemoryOverhead
+ if (amMem > maxMem) {
+ throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" +
+ s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!")
+ }
+ logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format(
+ amMem,
+ amMemoryOverhead))
+
+ // We could add checks to make sure the entire cluster has enough resources but that involves
+ // getting all the node reports and computing ourselves.
+ }
+
+ /**
+ * Copy the given file to a remote file system (e.g. HDFS) if needed.
+ * The file is only copied if the source and destination file systems are different. This is used
+ * for preparing resources for launching the ApplicationMaster container. Exposed for testing.
+ */
+ def copyFileToRemote(
+ destDir: Path,
+ srcPath: Path,
+ replication: Short,
+ setPerms: Boolean = false): Path = {
+ val destFs = destDir.getFileSystem(hadoopConf)
+ val srcFs = srcPath.getFileSystem(hadoopConf)
+ var destPath = srcPath
+ if (!compareFs(srcFs, destFs)) {
+ destPath = new Path(destDir, srcPath.getName())
+ logInfo(s"Uploading resource $srcPath -> $destPath")
+ FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf)
+ destFs.setReplication(destPath, replication)
+ if (setPerms) {
+ destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION))
+ }
+ } else {
+ logInfo(s"Source and destination file systems are the same. Not copying $srcPath")
+ }
+ // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific
+ // version shows the specific version in the distributed cache configuration
+ val qualifiedDestPath = destFs.makeQualified(destPath)
+ val fc = FileContext.getFileContext(qualifiedDestPath.toUri(), hadoopConf)
+ fc.resolvePath(qualifiedDestPath)
+ }
+
+ /**
+ * Given a local URI, resolve it and return a qualified local path that corresponds to the URI.
+ * This is used for preparing local resources to be included in the container launch context.
+ */
+ private def getQualifiedLocalPath(localURI: URI): Path = {
+ val qualifiedURI =
+ if (localURI.getScheme == null) {
+ // If not specified, assume this is in the local filesystem to keep the behavior
+ // consistent with that of Hadoop
+ new URI(FileSystem.getLocal(hadoopConf).makeQualified(new Path(localURI)).toString)
+ } else {
+ localURI
+ }
+ new Path(qualifiedURI)
+ }
+
+ /**
+ * Upload any resources to the distributed cache if needed. If a resource is intended to be
+ * consumed locally, set up the appropriate config for downstream code to handle it properly.
+ * This is used for setting up a container launch context for our ApplicationMaster.
+ * Exposed for testing.
+ */
+ def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
+ logInfo("Preparing resources for our AM container")
+ // Upload Spark and the application JAR to the remote file system if necessary,
+ // and add them as local resources to the application master.
+ val fs = FileSystem.get(hadoopConf)
+ val dst = new Path(fs.getHomeDirectory(), appStagingDir)
+ val nns = getNameNodesToAccess(sparkConf) + dst
+ obtainTokensForNamenodes(nns, hadoopConf, credentials)
+
+ val replication = sparkConf.getInt("spark.yarn.submit.file.replication",
+ fs.getDefaultReplication(dst)).toShort
+ val localResources = HashMap[String, LocalResource]()
+ FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION))
+
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+
+ val oldLog4jConf = Option(System.getenv("SPARK_LOG4J_CONF"))
+ if (oldLog4jConf.isDefined) {
+ logWarning(
+ "SPARK_LOG4J_CONF detected in the system environment. This variable has been " +
+ "deprecated. Please refer to the \"Launching Spark on YARN\" documentation " +
+ "for alternatives.")
+ }
+
+ /**
+ * Copy the given main resource to the distributed cache if the scheme is not "local".
+ * Otherwise, set the corresponding key in our SparkConf to handle it downstream.
+ * Each resource is represented by a 4-tuple of:
+ * (1) destination resource name,
+ * (2) local path to the resource,
+ * (3) Spark property key to set if the scheme is not local, and
+ * (4) whether to set permissions for this resource
+ */
+ List(
+ (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR, false),
+ (APP_JAR, args.userJar, CONF_SPARK_USER_JAR, true),
+ ("log4j.properties", oldLog4jConf.orNull, null, false)
+ ).foreach { case (destName, _localPath, confKey, setPermissions) =>
+ val localPath: String = if (_localPath != null) _localPath.trim() else ""
+ if (!localPath.isEmpty()) {
+ val localURI = new URI(localPath)
+ if (localURI.getScheme != LOCAL_SCHEME) {
+ val src = getQualifiedLocalPath(localURI)
+ val destPath = copyFileToRemote(dst, src, replication, setPermissions)
+ val destFs = FileSystem.get(destPath.toUri(), hadoopConf)
+ distCacheMgr.addResource(destFs, hadoopConf, destPath,
+ localResources, LocalResourceType.FILE, destName, statCache)
+ } else if (confKey != null) {
+ // If the resource is intended for local use only, handle this downstream
+ // by setting the appropriate property
+ sparkConf.set(confKey, localPath)
+ }
+ }
+ }
+
+ /**
+ * Do the same for any additional resources passed in through ClientArguments.
+ * Each resource category is represented by a 3-tuple of:
+ * (1) comma separated list of resources in this category,
+ * (2) resource type, and
+ * (3) whether to add these resources to the classpath
+ */
+ val cachedSecondaryJarLinks = ListBuffer.empty[String]
+ List(
+ (args.addJars, LocalResourceType.FILE, true),
+ (args.files, LocalResourceType.FILE, false),
+ (args.archives, LocalResourceType.ARCHIVE, false)
+ ).foreach { case (flist, resType, addToClasspath) =>
+ if (flist != null && !flist.isEmpty()) {
+ flist.split(',').foreach { file =>
+ val localURI = new URI(file.trim())
+ if (localURI.getScheme != LOCAL_SCHEME) {
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyFileToRemote(dst, localPath, replication)
+ distCacheMgr.addResource(
+ fs, hadoopConf, destPath, localResources, resType, linkname, statCache)
+ if (addToClasspath) {
+ cachedSecondaryJarLinks += linkname
+ }
+ } else if (addToClasspath) {
+ // Resource is intended for local use only and should be added to the class path
+ cachedSecondaryJarLinks += file.trim()
+ }
+ }
+ }
+ }
+ if (cachedSecondaryJarLinks.nonEmpty) {
+ sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(","))
+ }
+
+ localResources
+ }
+
+ /**
+ * Set up the environment for launching our ApplicationMaster container.
+ */
+ private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = {
+ logInfo("Setting up the launch environment for our AM container")
+ val env = new HashMap[String, String]()
+ val extraCp = sparkConf.getOption("spark.driver.extraClassPath")
+ populateClasspath(args, yarnConf, sparkConf, env, extraCp)
+ env("SPARK_YARN_MODE") = "true"
+ env("SPARK_YARN_STAGING_DIR") = stagingDir
+ env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
+
+ // Set the environment variables to be passed on to the executors.
+ distCacheMgr.setDistFilesEnv(env)
+ distCacheMgr.setDistArchivesEnv(env)
+
+ // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.*
+ val amEnvPrefix = "spark.yarn.appMasterEnv."
+ sparkConf.getAll
+ .filter { case (k, v) => k.startsWith(amEnvPrefix) }
+ .map { case (k, v) => (k.substring(amEnvPrefix.length), v) }
+ .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) }
+
+ // Keep this for backwards compatibility but users should move to the config
+ sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs =>
+ // Allow users to specify some environment variables.
+ YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs)
+ // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments.
+ env("SPARK_YARN_USER_ENV") = userEnvs
+ }
+
+ // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to
+ // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's
+ // SparkContext will not let that set spark* system properties, which is expected behavior for
+ // Yarn clients. So propagate it through the environment.
+ //
+ // Note that to warn the user about the deprecation in cluster mode, some code from
+ // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition
+ // described above).
+ if (isLaunchingDriver) {
+ sys.env.get("SPARK_JAVA_OPTS").foreach { value =>
+ val warning =
+ s"""
+ |SPARK_JAVA_OPTS was detected (set to '$value').
+ |This is deprecated in Spark 1.0+.
+ |
+ |Please instead use:
+ | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application
+ | - ./spark-submit with --driver-java-options to set -X options for a driver
+ | - spark.executor.extraJavaOptions to set -X options for executors
+ """.stripMargin
+ logWarning(warning)
+ for (proc <- Seq("driver", "executor")) {
+ val key = s"spark.$proc.extraJavaOptions"
+ if (sparkConf.contains(key)) {
+ throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.")
+ }
+ }
+ env("SPARK_JAVA_OPTS") = value
+ }
+ }
+
+ env
+ }
+
+ /**
+ * Set up a ContainerLaunchContext to launch our ApplicationMaster container.
+ * This sets up the launch environment, java options, and the command for launching the AM.
+ */
+ protected def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse)
+ : ContainerLaunchContext = {
+ logInfo("Setting up container launch context for our AM")
+
+ val appId = newAppResponse.getApplicationId
+ val appStagingDir = getAppStagingDir(appId)
+ val localResources = prepareLocalResources(appStagingDir)
+ val launchEnv = setupLaunchEnv(appStagingDir)
+ val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
+ amContainer.setLocalResources(localResources)
+ amContainer.setEnvironment(launchEnv)
+
+ val javaOpts = ListBuffer[String]()
+
+ // Set the environment variable through a command prefix
+ // to append to the existing value of the variable
+ var prefixEnv: Option[String] = None
+
+ // Add Xmx for AM memory
+ javaOpts += "-Xmx" + args.amMemory + "m"
+
+ val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
+ javaOpts += "-Djava.io.tmpdir=" + tmpDir
+
+ // TODO: Remove once cpuset version is pushed out.
+ // The context is, default gc for server class machines ends up using all cores to do gc -
+ // hence if there are multiple containers in same node, Spark GC affects all other containers'
+ // performance (which can be that of other Spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in
+ // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset
+ // of cores on a node.
+ val useConcurrentAndIncrementalGC = launchEnv.get("SPARK_USE_CONC_INCR_GC").exists(_.toBoolean)
+ if (useConcurrentAndIncrementalGC) {
+ // In our expts, using (default) throughput collector has severe perf ramifications in
+ // multi-tenant machines
+ javaOpts += "-XX:+UseConcMarkSweepGC"
+ javaOpts += "-XX:+CMSIncrementalMode"
+ javaOpts += "-XX:+CMSIncrementalPacing"
+ javaOpts += "-XX:CMSIncrementalDutyCycleMin=0"
+ javaOpts += "-XX:CMSIncrementalDutyCycle=10"
+ }
+
+ // Forward the Spark configuration to the application master / executors.
+ // TODO: it might be nicer to pass these as an internal environment variable rather than
+ // as Java options, due to complications with string parsing of nested quotes.
+ for ((k, v) <- sparkConf.getAll) {
+ javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v")
+ }
+
+ // Include driver-specific java options if we are launching a driver
+ if (isLaunchingDriver) {
+ sparkConf.getOption("spark.driver.extraJavaOptions")
+ .orElse(sys.env.get("SPARK_JAVA_OPTS"))
+ .foreach(opts => javaOpts += opts)
+ val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"),
+ sys.props.get("spark.driver.libraryPath")).flatten
+ if (libraryPaths.nonEmpty) {
+ prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths))
+ }
+ }
+
+ // For log4j configuration to reference
+ javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR)
+
+ val userClass =
+ if (isLaunchingDriver) {
+ Seq("--class", YarnSparkHadoopUtil.escapeForShell(args.userClass))
+ } else {
+ Nil
+ }
+ val userJar =
+ if (args.userJar != null) {
+ Seq("--jar", args.userJar)
+ } else {
+ Nil
+ }
+ val amClass =
+ if (isLaunchingDriver) {
+ Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName
+ } else {
+ Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName
+ }
+ val userArgs = args.userArgs.flatMap { arg =>
+ Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg))
+ }
+ val amArgs =
+ Seq(amClass) ++ userClass ++ userJar ++ userArgs ++
+ Seq(
+ "--executor-memory", args.executorMemory.toString + "m",
+ "--executor-cores", args.executorCores.toString,
+ "--num-executors ", args.numExecutors.toString)
+
+ // Command for the ApplicationMaster
+ val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++
+ javaOpts ++ amArgs ++
+ Seq(
+ "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout",
+ "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+
+ // TODO: it would be nicer to just make sure there are no null commands here
+ val printableCommands = commands.map(s => if (s == null) "null" else s).toList
+ amContainer.setCommands(printableCommands)
+
+ logDebug("===============================================================================")
+ logDebug("Yarn AM launch context:")
+ logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}")
+ logDebug(" env:")
+ launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") }
+ logDebug(" resources:")
+ localResources.foreach { case (k, v) => logDebug(s" $k -> $v")}
+ logDebug(" command:")
+ logDebug(s" ${printableCommands.mkString(" ")}")
+ logDebug("===============================================================================")
+
+ // send the acl settings into YARN to control who has access via YARN interfaces
+ val securityManager = new SecurityManager(sparkConf)
+ amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager))
+ setupSecurityToken(amContainer)
+ UserGroupInformation.getCurrentUser().addCredentials(credentials)
+
+ amContainer
+ }
+
+ /**
+ * Report the state of an application until it has exited, either successfully or
+ * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED,
+ * KILLED, or RUNNING) and the final application state (UNDEFINED, SUCCEEDED, FAILED,
+ * or KILLED).
+ *
+ * @param appId ID of the application to monitor.
+ * @param returnOnRunning Whether to also return the application state when it is RUNNING.
+ * @param logApplicationReport Whether to log details of the application report every iteration.
+ * @return A pair of the yarn application state and the final application state.
+ */
+ def monitorApplication(
+ appId: ApplicationId,
+ returnOnRunning: Boolean = false,
+ logApplicationReport: Boolean = true): (YarnApplicationState, FinalApplicationStatus) = {
+ val interval = sparkConf.getLong("spark.yarn.report.interval", 1000)
+ var lastState: YarnApplicationState = null
+ while (true) {
+ Thread.sleep(interval)
+ val report = getApplicationReport(appId)
+ val state = report.getYarnApplicationState
+
+ if (logApplicationReport) {
+ logInfo(s"Application report for $appId (state: $state)")
+ val details = Seq[(String, String)](
+ ("client token", getClientToken(report)),
+ ("diagnostics", report.getDiagnostics),
+ ("ApplicationMaster host", report.getHost),
+ ("ApplicationMaster RPC port", report.getRpcPort.toString),
+ ("queue", report.getQueue),
+ ("start time", report.getStartTime.toString),
+ ("final status", report.getFinalApplicationStatus.toString),
+ ("tracking URL", report.getTrackingUrl),
+ ("user", report.getUser)
+ )
+
+ // Use more loggable format if value is null or empty
+ val formattedDetails = details
+ .map { case (k, v) =>
+ val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A")
+ s"\n\t $k: $newValue" }
+ .mkString("")
+
+ // If DEBUG is enabled, log report details every iteration
+ // Otherwise, log them every time the application changes state
+ if (log.isDebugEnabled) {
+ logDebug(formattedDetails)
+ } else if (lastState != state) {
+ logInfo(formattedDetails)
+ }
+ }
+
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ return (state, report.getFinalApplicationStatus)
+ }
+
+ if (returnOnRunning && state == YarnApplicationState.RUNNING) {
+ return (state, report.getFinalApplicationStatus)
+ }
+
+ lastState = state
+ }
+
+ // Never reached, but keeps compiler happy
+ throw new SparkException("While loop is depleted! This should never happen...")
+ }
+
+ /**
+ * Submit an application to the ResourceManager and monitor its state.
+ * This continues until the application has exited for any reason.
+ * If the application finishes with a failed, killed, or undefined status,
+ * throw an appropriate SparkException.
+ */
+ def run(): Unit = {
+ val (yarnApplicationState, finalApplicationStatus) = monitorApplication(submitApplication())
+ if (yarnApplicationState == YarnApplicationState.FAILED ||
+ finalApplicationStatus == FinalApplicationStatus.FAILED) {
+ throw new SparkException("Application finished with failed status")
+ }
+ if (yarnApplicationState == YarnApplicationState.KILLED ||
+ finalApplicationStatus == FinalApplicationStatus.KILLED) {
+ throw new SparkException("Application is killed")
+ }
+ if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) {
+ throw new SparkException("The final status of application is undefined")
+ }
+ }
+
+ /* --------------------------------------------------------------------------------------- *
+ | Methods that cannot be implemented here due to API differences across hadoop versions |
+ * --------------------------------------------------------------------------------------- */
+
+ /** Submit an application running our ApplicationMaster to the ResourceManager. */
+ def submitApplication(): ApplicationId
+
+ /** Set up security tokens for launching our ApplicationMaster container. */
+ protected def setupSecurityToken(containerContext: ContainerLaunchContext): Unit
+
+ /** Get the application report from the ResourceManager for an application we have submitted. */
+ protected def getApplicationReport(appId: ApplicationId): ApplicationReport
+
+ /**
+ * Return the security token used by this client to communicate with the ApplicationMaster.
+ * If no security is enabled, the token returned by the report is null.
+ */
+ protected def getClientToken(report: ApplicationReport): String
+}
+
+private[spark] object ClientBase extends Logging {
+
+ // Alias for the Spark assembly jar and the user jar
+ val SPARK_JAR: String = "__spark__.jar"
+ val APP_JAR: String = "__app__.jar"
+
+ // URI scheme that identifies local resources
+ val LOCAL_SCHEME = "local"
+
+ // Staging directory for any temporary jars or files
+ val SPARK_STAGING: String = ".sparkStaging"
+
+ // Location of any user-defined Spark jars
+ val CONF_SPARK_JAR = "spark.yarn.jar"
+ val ENV_SPARK_JAR = "SPARK_JAR"
+
+ // Internal config to propagate the location of the user's jar to the driver/executors
+ val CONF_SPARK_USER_JAR = "spark.yarn.user.jar"
+
+ // Internal config to propagate the locations of any extra jars to add to the classpath
+ // of the executors
+ val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars"
+
+ // Staging directory is private! -> rwx--------
+ val STAGING_DIR_PERMISSION: FsPermission =
+ FsPermission.createImmutable(Integer.parseInt("700", 8).toShort)
+
+ // App files are world-wide readable and owner writable -> rw-r--r--
+ val APP_FILE_PERMISSION: FsPermission =
+ FsPermission.createImmutable(Integer.parseInt("644", 8).toShort)
+
+ /**
+ * Find the user-defined Spark jar if configured, or return the jar containing this
+ * class if not.
+ *
+ * This method first looks in the SparkConf object for the CONF_SPARK_JAR key, and in the
+ * user environment if that is not found (for backwards compatibility).
+ */
+ private def sparkJar(conf: SparkConf): String = {
+ if (conf.contains(CONF_SPARK_JAR)) {
+ conf.get(CONF_SPARK_JAR)
+ } else if (System.getenv(ENV_SPARK_JAR) != null) {
+ logWarning(
+ s"$ENV_SPARK_JAR detected in the system environment. This variable has been deprecated " +
+ s"in favor of the $CONF_SPARK_JAR configuration variable.")
+ System.getenv(ENV_SPARK_JAR)
+ } else {
+ SparkContext.jarOfClass(this.getClass).head
+ }
+ }
+
+ /**
+ * Return the path to the given application's staging directory.
+ */
+ private def getAppStagingDir(appId: ApplicationId): String = {
+ SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR
+ }
+
+ /**
+ * Populate the classpath entry in the given environment map with any application
+ * classpath specified through the Hadoop and Yarn configurations.
+ */
+ def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]): Unit = {
+ val classPathElementsToAdd = getYarnAppClasspath(conf) ++ getMRAppClasspath(conf)
+ for (c <- classPathElementsToAdd.flatten) {
+ YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, c.trim)
+ }
+ }
+
+ private def getYarnAppClasspath(conf: Configuration): Option[Seq[String]] =
+ Option(conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) match {
+ case Some(s) => Some(s.toSeq)
+ case None => getDefaultYarnApplicationClasspath
+ }
+
+ private def getMRAppClasspath(conf: Configuration): Option[Seq[String]] =
+ Option(conf.getStrings("mapreduce.application.classpath")) match {
+ case Some(s) => Some(s.toSeq)
+ case None => getDefaultMRApplicationClasspath
+ }
+
+ def getDefaultYarnApplicationClasspath: Option[Seq[String]] = {
+ val triedDefault = Try[Seq[String]] {
+ val field = classOf[YarnConfiguration].getField("DEFAULT_YARN_APPLICATION_CLASSPATH")
+ val value = field.get(null).asInstanceOf[Array[String]]
+ value.toSeq
+ } recoverWith {
+ case e: NoSuchFieldException => Success(Seq.empty[String])
+ }
+
+ triedDefault match {
+ case f: Failure[_] =>
+ logError("Unable to obtain the default YARN Application classpath.", f.exception)
+ case s: Success[_] =>
+ logDebug(s"Using the default YARN application classpath: ${s.get.mkString(",")}")
+ }
+
+ triedDefault.toOption
+ }
+
+ /**
+ * In Hadoop 0.23, the MR application classpath comes with the YARN application
+ * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String.
+ * So we need to use reflection to retrieve it.
+ */
+ def getDefaultMRApplicationClasspath: Option[Seq[String]] = {
+ val triedDefault = Try[Seq[String]] {
+ val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH")
+ val value = if (field.getType == classOf[String]) {
+ StringUtils.getStrings(field.get(null).asInstanceOf[String]).toArray
+ } else {
+ field.get(null).asInstanceOf[Array[String]]
+ }
+ value.toSeq
+ } recoverWith {
+ case e: NoSuchFieldException => Success(Seq.empty[String])
+ }
+
+ triedDefault match {
+ case f: Failure[_] =>
+ logError("Unable to obtain the default MR Application classpath.", f.exception)
+ case s: Success[_] =>
+ logDebug(s"Using the default MR application classpath: ${s.get.mkString(",")}")
+ }
+
+ triedDefault.toOption
+ }
+
+ /**
+ * Populate the classpath entry in the given environment map.
+ * This includes the user jar, Spark jar, and any extra application jars.
+ */
+ def populateClasspath(
+ args: ClientArguments,
+ conf: Configuration,
+ sparkConf: SparkConf,
+ env: HashMap[String, String],
+ extraClassPath: Option[String] = None): Unit = {
+ extraClassPath.foreach(addClasspathEntry(_, env))
+ addClasspathEntry(Environment.PWD.$(), env)
+
+ // Normally the users app.jar is last in case conflicts with spark jars
+ if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) {
+ addUserClasspath(args, sparkConf, env)
+ addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env)
+ populateHadoopClasspath(conf, env)
+ } else {
+ addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env)
+ populateHadoopClasspath(conf, env)
+ addUserClasspath(args, sparkConf, env)
+ }
+
+ // Append all jar files under the working directory to the classpath.
+ addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env)
+ }
+
+ /**
+ * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly
+ * to the classpath.
+ */
+ private def addUserClasspath(
+ args: ClientArguments,
+ conf: SparkConf,
+ env: HashMap[String, String]): Unit = {
+
+ // If `args` is not null, we are launching an AM container.
+ // Otherwise, we are launching executor containers.
+ val (mainJar, secondaryJars) =
+ if (args != null) {
+ (args.userJar, args.addJars)
+ } else {
+ (conf.get(CONF_SPARK_USER_JAR, null), conf.get(CONF_SPARK_YARN_SECONDARY_JARS, null))
+ }
+
+ addFileToClasspath(mainJar, APP_JAR, env)
+ if (secondaryJars != null) {
+ secondaryJars.split(",").filter(_.nonEmpty).foreach { jar =>
+ addFileToClasspath(jar, null, env)
+ }
+ }
+ }
+
+ /**
+ * Adds the given path to the classpath, handling "local:" URIs correctly.
+ *
+ * If an alternate name for the file is given, and it's not a "local:" file, the alternate
+ * name will be added to the classpath (relative to the job's work directory).
+ *
+ * If not a "local:" file and no alternate name, the environment is not modified.
+ *
+ * @param path Path to add to classpath (optional).
+ * @param fileName Alternate name for the file (optional).
+ * @param env Map holding the environment variables.
+ */
+ private def addFileToClasspath(
+ path: String,
+ fileName: String,
+ env: HashMap[String, String]): Unit = {
+ if (path != null) {
+ scala.util.control.Exception.ignoring(classOf[URISyntaxException]) {
+ val uri = new URI(path)
+ if (uri.getScheme == LOCAL_SCHEME) {
+ addClasspathEntry(uri.getPath, env)
+ return
+ }
+ }
+ }
+ if (fileName != null) {
+ addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env)
+ }
+ }
+
+ /**
+ * Add the given path to the classpath entry of the given environment map.
+ * If the classpath is already set, this appends the new path to the existing classpath.
+ */
+ private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit =
+ YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path)
+
+ /**
+ * Get the list of namenodes the user may access.
+ */
+ def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = {
+ sparkConf.get("spark.yarn.access.namenodes", "")
+ .split(",")
+ .map(_.trim())
+ .filter(!_.isEmpty)
+ .map(new Path(_))
+ .toSet
+ }
+
+ def getTokenRenewer(conf: Configuration): String = {
+ val delegTokenRenewer = Master.getMasterPrincipal(conf)
+ logDebug("delegation token renewer is: " + delegTokenRenewer)
+ if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
+ val errorMessage = "Can't get Master Kerberos principal for use as renewer"
+ logError(errorMessage)
+ throw new SparkException(errorMessage)
+ }
+ delegTokenRenewer
+ }
+
+ /**
+ * Obtains tokens for the namenodes passed in and adds them to the credentials.
+ */
+ def obtainTokensForNamenodes(
+ paths: Set[Path],
+ conf: Configuration,
+ creds: Credentials): Unit = {
+ if (UserGroupInformation.isSecurityEnabled()) {
+ val delegTokenRenewer = getTokenRenewer(conf)
+ paths.foreach { dst =>
+ val dstFs = dst.getFileSystem(conf)
+ logDebug("getting token for namenode: " + dst)
+ dstFs.addDelegationTokens(delegTokenRenewer, creds)
+ }
+ }
+ }
+
+ /**
+ * Return whether the two file systems are the same.
+ */
+ private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = {
+ val srcUri = srcFs.getUri()
+ val dstUri = destFs.getUri()
+ if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) {
+ return false
+ }
+
+ var srcHost = srcUri.getHost()
+ var dstHost = dstUri.getHost()
+
+ // In HA or when using viewfs, the host part of the URI may not actually be a host, but the
+ // name of the HDFS namespace. Those names won't resolve, so avoid even trying if they
+ // match.
+ if (srcHost != null && dstHost != null && srcHost != dstHost) {
+ try {
+ srcHost = InetAddress.getByName(srcHost).getCanonicalHostName()
+ dstHost = InetAddress.getByName(dstHost).getCanonicalHostName()
+ } catch {
+ case e: UnknownHostException =>
+ return false
+ }
+ }
+
+ Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort()
+ }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
new file mode 100644
index 0000000000..c592ecfdfc
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import scala.collection.mutable.{HashMap, LinkedHashMap, Map}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import org.apache.spark.Logging
+
+/** Client side methods to setup the Hadoop distributed cache */
+private[spark] class ClientDistributedCacheManager() extends Logging {
+
+ // Mappings from remote URI to (file status, modification time, visibility)
+ private val distCacheFiles: Map[String, (String, String, String)] =
+ LinkedHashMap[String, (String, String, String)]()
+ private val distCacheArchives: Map[String, (String, String, String)] =
+ LinkedHashMap[String, (String, String, String)]()
+
+
+ /**
+ * Add a resource to the list of distributed cache resources. This list can
+ * be sent to the ApplicationMaster and possibly the executors so that it can
+ * be downloaded into the Hadoop distributed cache for use by this application.
+ * Adds the LocalResource to the localResources HashMap passed in and saves
+ * the stats of the resources to they can be sent to the executors and verified.
+ *
+ * @param fs FileSystem
+ * @param conf Configuration
+ * @param destPath path to the resource
+ * @param localResources localResource hashMap to insert the resource into
+ * @param resourceType LocalResourceType
+ * @param link link presented in the distributed cache to the destination
+ * @param statCache cache to store the file/directory stats
+ * @param appMasterOnly Whether to only add the resource to the app master
+ */
+ def addResource(
+ fs: FileSystem,
+ conf: Configuration,
+ destPath: Path,
+ localResources: HashMap[String, LocalResource],
+ resourceType: LocalResourceType,
+ link: String,
+ statCache: Map[URI, FileStatus],
+ appMasterOnly: Boolean = false): Unit = {
+ val destStatus = fs.getFileStatus(destPath)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource])
+ amJarRsrc.setType(resourceType)
+ val visibility = getVisibility(conf, destPath.toUri(), statCache)
+ amJarRsrc.setVisibility(visibility)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name")
+ localResources(link) = amJarRsrc
+
+ if (!appMasterOnly) {
+ val uri = destPath.toUri()
+ val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link)
+ if (resourceType == LocalResourceType.FILE) {
+ distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ } else {
+ distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ }
+ }
+ }
+
+ /**
+ * Adds the necessary cache file env variables to the env passed in
+ */
+ def setDistFilesEnv(env: Map[String, String]): Unit = {
+ val (keys, tupleValues) = distCacheFiles.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Adds the necessary cache archive env variables to the env passed in
+ */
+ def setDistArchivesEnv(env: Map[String, String]): Unit = {
+ val (keys, tupleValues) = distCacheArchives.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Returns the local resource visibility depending on the cache file permissions
+ * @return LocalResourceVisibility
+ */
+ def getVisibility(
+ conf: Configuration,
+ uri: URI,
+ statCache: Map[URI, FileStatus]): LocalResourceVisibility = {
+ if (isPublic(conf, uri, statCache)) {
+ LocalResourceVisibility.PUBLIC
+ } else {
+ LocalResourceVisibility.PRIVATE
+ }
+ }
+
+ /**
+ * Returns a boolean to denote whether a cache file is visible to all (public)
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = {
+ val fs = FileSystem.get(uri, conf)
+ val current = new Path(uri.getPath())
+ // the leaf level file should be readable by others
+ if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) {
+ return false
+ }
+ ancestorsHaveExecutePermissions(fs, current.getParent(), statCache)
+ }
+
+ /**
+ * Returns true if all ancestors of the specified path have the 'execute'
+ * permission set for all users (i.e. that other users can traverse
+ * the directory hierarchy to the given path)
+ * @return true if all ancestors have the 'execute' permission set for all users
+ */
+ def ancestorsHaveExecutePermissions(
+ fs: FileSystem,
+ path: Path,
+ statCache: Map[URI, FileStatus]): Boolean = {
+ var current = path
+ while (current != null) {
+ // the subdirs in the path should have execute permissions for others
+ if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) {
+ return false
+ }
+ current = current.getParent()
+ }
+ true
+ }
+
+ /**
+ * Checks for a given path whether the Other permissions on it
+ * imply the permission in the passed FsAction
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def checkPermissionOfOther(
+ fs: FileSystem,
+ path: Path,
+ action: FsAction,
+ statCache: Map[URI, FileStatus]): Boolean = {
+ val status = getFileStatus(fs, path.toUri(), statCache)
+ val perms = status.getPermission()
+ val otherAction = perms.getOtherAction()
+ otherAction.implies(action)
+ }
+
+ /**
+ * Checks to see if the given uri exists in the cache, if it does it
+ * returns the existing FileStatus, otherwise it stats the uri, stores
+ * it in the cache, and returns the FileStatus.
+ * @return FileStatus
+ */
+ def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = {
+ val stat = statCache.get(uri) match {
+ case Some(existstat) => existstat
+ case None =>
+ val newStat = fs.getFileStatus(new Path(uri))
+ statCache.put(uri, newStat)
+ newStat
+ }
+ stat
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
new file mode 100644
index 0000000000..fdd3c2300f
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.nio.ByteBuffer
+import java.security.PrivilegedExceptionAction
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.records.impl.pb.ProtoUtils
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.client.api.NMClient
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
+
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
+import org.apache.spark.network.util.JavaUtils
+
+
+class ExecutorRunnable(
+ container: Container,
+ conf: Configuration,
+ spConf: SparkConf,
+ masterAddress: String,
+ slaveId: String,
+ hostname: String,
+ executorMemory: Int,
+ executorCores: Int,
+ appId: String,
+ securityMgr: SecurityManager)
+ extends Runnable with ExecutorRunnableUtil with Logging {
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ var nmClient: NMClient = _
+ val sparkConf = spConf
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run = {
+ logInfo("Starting Executor Container")
+ nmClient = NMClient.createNMClient()
+ nmClient.init(yarnConf)
+ nmClient.start()
+ startContainer
+ }
+
+ def startContainer = {
+ logInfo("Setting up ContainerLaunchContext")
+
+ val ctx = Records.newRecord(classOf[ContainerLaunchContext])
+ .asInstanceOf[ContainerLaunchContext]
+
+ val localResources = prepareLocalResources
+ ctx.setLocalResources(localResources)
+
+ ctx.setEnvironment(env)
+
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+ val dob = new DataOutputBuffer()
+ credentials.writeTokenStorageToStream(dob)
+ ctx.setTokens(ByteBuffer.wrap(dob.getData()))
+
+ val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores,
+ appId, localResources)
+
+ logInfo(s"Setting up executor with environment: $env")
+ logInfo("Setting up executor with commands: " + commands)
+ ctx.setCommands(commands)
+
+ ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr))
+
+ // If external shuffle service is enabled, register with the Yarn shuffle service already
+ // started on the NodeManager and, if authentication is enabled, provide it with our secret
+ // key for fetching shuffle files later
+ if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) {
+ val secretString = securityMgr.getSecretKey()
+ val secretBytes =
+ if (secretString != null) {
+ // This conversion must match how the YarnShuffleService decodes our secret
+ JavaUtils.stringToBytes(secretString)
+ } else {
+ // Authentication is not enabled, so just provide dummy metadata
+ ByteBuffer.allocate(0)
+ }
+ ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes))
+ }
+
+ // Send the start request to the ContainerManager
+ nmClient.startContainer(container, ctx)
+ }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
new file mode 100644
index 0000000000..22d73ecf6d
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{HashMap, ListBuffer}
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.util.Utils
+
+trait ExecutorRunnableUtil extends Logging {
+
+ val yarnConf: YarnConfiguration
+ val sparkConf: SparkConf
+ lazy val env = prepareEnvironment
+
+ def prepareCommand(
+ masterAddress: String,
+ slaveId: String,
+ hostname: String,
+ executorMemory: Int,
+ executorCores: Int,
+ appId: String,
+ localResources: HashMap[String, LocalResource]): List[String] = {
+ // Extra options for the JVM
+ val javaOpts = ListBuffer[String]()
+
+ // Set the environment variable through a command prefix
+ // to append to the existing value of the variable
+ var prefixEnv: Option[String] = None
+
+ // Set the JVM memory
+ val executorMemoryString = executorMemory + "m"
+ javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " "
+
+ // Set extra Java options for the executor, if defined
+ sys.props.get("spark.executor.extraJavaOptions").foreach { opts =>
+ javaOpts += opts
+ }
+ sys.env.get("SPARK_JAVA_OPTS").foreach { opts =>
+ javaOpts += opts
+ }
+ sys.props.get("spark.executor.extraLibraryPath").foreach { p =>
+ prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p)))
+ }
+
+ javaOpts += "-Djava.io.tmpdir=" +
+ new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
+
+ // Certain configs need to be passed here because they are needed before the Executor
+ // registers with the Scheduler and transfers the spark configs. Since the Executor backend
+ // uses Akka to connect to the scheduler, the akka settings are needed as well as the
+ // authentication settings.
+ sparkConf.getAll.
+ filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }.
+ foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") }
+
+ sparkConf.getAkkaConf.
+ foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") }
+
+ // Commenting it out for now - so that people can refer to the properties if required. Remove
+ // it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence
+ // if there are multiple containers in same node, spark gc effects all other containers
+ // performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in
+ // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset
+ // of cores on a node.
+ /*
+ else {
+ // If no java_opts specified, default to using -XX:+CMSIncrementalMode
+ // It might be possible that other modes/config is being done in
+ // spark.executor.extraJavaOptions, so we dont want to mess with it.
+ // In our expts, using (default) throughput collector has severe perf ramnifications in
+ // multi-tennent machines
+ // The options are based on
+ // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use
+ // %20the%20Concurrent%20Low%20Pause%20Collector|outline
+ javaOpts += " -XX:+UseConcMarkSweepGC "
+ javaOpts += " -XX:+CMSIncrementalMode "
+ javaOpts += " -XX:+CMSIncrementalPacing "
+ javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 "
+ javaOpts += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+ */
+
+ // For log4j configuration to reference
+ javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR)
+
+ val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java",
+ "-server",
+ // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
+ // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in
+ // an inconsistent state.
+ // TODO: If the OOM is not recoverable by rescheduling it on different node, then do
+ // 'something' to fail job ... akin to blacklisting trackers in mapred ?
+ "-XX:OnOutOfMemoryError='kill %p'") ++
+ javaOpts ++
+ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend",
+ masterAddress.toString,
+ slaveId.toString,
+ hostname.toString,
+ executorCores.toString,
+ appId,
+ "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout",
+ "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+
+ // TODO: it would be nicer to just make sure there are no null commands here
+ commands.map(s => if (s == null) "null" else s).toList
+ }
+
+ private def setupDistributedCache(
+ file: String,
+ rtype: LocalResourceType,
+ localResources: HashMap[String, LocalResource],
+ timestamp: String,
+ size: String,
+ vis: String): Unit = {
+ val uri = new URI(file)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource])
+ amJarRsrc.setType(rtype)
+ amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis))
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri))
+ amJarRsrc.setTimestamp(timestamp.toLong)
+ amJarRsrc.setSize(size.toLong)
+ localResources(uri.getFragment()) = amJarRsrc
+ }
+
+ def prepareLocalResources: HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val localResources = HashMap[String, LocalResource]()
+
+ if (System.getenv("SPARK_YARN_CACHE_FILES") != null) {
+ val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',')
+ val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',')
+ for( i <- 0 to distFiles.length - 1) {
+ setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i),
+ fileSizes(i), visibilities(i))
+ }
+ }
+
+ if (System.getenv("SPARK_YARN_CACHE_ARCHIVES") != null) {
+ val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',')
+ val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',')
+ val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',')
+ val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',')
+ for( i <- 0 to distArchives.length - 1) {
+ setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources,
+ timeStamps(i), fileSizes(i), visibilities(i))
+ }
+ }
+
+ logInfo("Prepared Local resources " + localResources)
+ localResources
+ }
+
+ def prepareEnvironment: HashMap[String, String] = {
+ val env = new HashMap[String, String]()
+ val extraCp = sparkConf.getOption("spark.executor.extraClassPath")
+ ClientBase.populateClasspath(null, yarnConf, sparkConf, env, extraCp)
+
+ sparkConf.getExecutorEnv.foreach { case (key, value) =>
+ // This assumes each executor environment variable set here is a path
+ // This is kept for backward compatibility and consistency with hadoop
+ YarnSparkHadoopUtil.addPathToEnvironment(env, key, value)
+ }
+
+ // Keep this for backwards compatibility but users should move to the config
+ sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs =>
+ YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs)
+ }
+
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v }
+ env
+ }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
new file mode 100644
index 0000000000..2bbf5d7db8
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.scheduler.SplitInfo
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.util.Records
+
+/**
+ * Acquires resources for executors from a ResourceManager and launches executors in new containers.
+ */
+private[yarn] class YarnAllocationHandler(
+ conf: Configuration,
+ sparkConf: SparkConf,
+ amClient: AMRMClient[ContainerRequest],
+ appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments,
+ preferredNodes: collection.Map[String, collection.Set[SplitInfo]],
+ securityMgr: SecurityManager)
+ extends YarnAllocator(conf, sparkConf, appAttemptId, args, preferredNodes, securityMgr) {
+
+ override protected def releaseContainer(container: Container) = {
+ amClient.releaseAssignedContainer(container.getId())
+ }
+
+ // pending isn't used on stable as the AMRMClient handles incremental asks
+ override protected def allocateContainers(count: Int, pending: Int): YarnAllocateResponse = {
+ addResourceRequests(count)
+
+ // We have already set the container request. Poll the ResourceManager for a response.
+ // This doubles as a heartbeat if there are no pending container requests.
+ val progressIndicator = 0.1f
+ new StableAllocateResponse(amClient.allocate(progressIndicator))
+ }
+
+ private def createRackResourceRequests(
+ hostContainers: ArrayBuffer[ContainerRequest]
+ ): ArrayBuffer[ContainerRequest] = {
+ // Generate modified racks and new set of hosts under it before issuing requests.
+ val rackToCounts = new HashMap[String, Int]()
+
+ for (container <- hostContainers) {
+ val candidateHost = container.getNodes.last
+ assert(YarnSparkHadoopUtil.ANY_HOST != candidateHost)
+
+ val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost)
+ if (rack != null) {
+ var count = rackToCounts.getOrElse(rack, 0)
+ count += 1
+ rackToCounts.put(rack, count)
+ }
+ }
+
+ val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size)
+ for ((rack, count) <- rackToCounts) {
+ requestedContainers ++= createResourceRequests(
+ AllocationType.RACK,
+ rack,
+ count,
+ YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)
+ }
+
+ requestedContainers
+ }
+
+ private def addResourceRequests(numExecutors: Int) {
+ val containerRequests: List[ContainerRequest] =
+ if (numExecutors <= 0) {
+ logDebug("numExecutors: " + numExecutors)
+ List()
+ } else if (preferredHostToCount.isEmpty) {
+ logDebug("host preferences is empty")
+ createResourceRequests(
+ AllocationType.ANY,
+ resource = null,
+ numExecutors,
+ YarnSparkHadoopUtil.RM_REQUEST_PRIORITY).toList
+ } else {
+ // Request for all hosts in preferred nodes and for numExecutors -
+ // candidates.size, request by default allocation policy.
+ val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size)
+ for ((candidateHost, candidateCount) <- preferredHostToCount) {
+ val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
+
+ if (requiredCount > 0) {
+ hostContainerRequests ++= createResourceRequests(
+ AllocationType.HOST,
+ candidateHost,
+ requiredCount,
+ YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)
+ }
+ }
+ val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests(
+ hostContainerRequests).toList
+
+ val anyContainerRequests = createResourceRequests(
+ AllocationType.ANY,
+ resource = null,
+ numExecutors,
+ YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)
+
+ val containerRequestBuffer = new ArrayBuffer[ContainerRequest](
+ hostContainerRequests.size + rackContainerRequests.size() + anyContainerRequests.size)
+
+ containerRequestBuffer ++= hostContainerRequests
+ containerRequestBuffer ++= rackContainerRequests
+ containerRequestBuffer ++= anyContainerRequests
+ containerRequestBuffer.toList
+ }
+
+ for (request <- containerRequests) {
+ amClient.addContainerRequest(request)
+ }
+
+ for (request <- containerRequests) {
+ val nodes = request.getNodes
+ var hostStr = if (nodes == null || nodes.isEmpty) {
+ "Any"
+ } else {
+ nodes.last
+ }
+ logInfo("Container request (host: %s, priority: %s, capability: %s".format(
+ hostStr,
+ request.getPriority().getPriority,
+ request.getCapability))
+ }
+ }
+
+ private def createResourceRequests(
+ requestType: AllocationType.AllocationType,
+ resource: String,
+ numExecutors: Int,
+ priority: Int
+ ): ArrayBuffer[ContainerRequest] = {
+
+ // If hostname is specified, then we need at least two requests - node local and rack local.
+ // There must be a third request, which is ANY. That will be specially handled.
+ requestType match {
+ case AllocationType.HOST => {
+ assert(YarnSparkHadoopUtil.ANY_HOST != resource)
+ val hostname = resource
+ val nodeLocal = constructContainerRequests(
+ Array(hostname),
+ racks = null,
+ numExecutors,
+ priority)
+
+ // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler.
+ YarnSparkHadoopUtil.populateRackInfo(conf, hostname)
+ nodeLocal
+ }
+ case AllocationType.RACK => {
+ val rack = resource
+ constructContainerRequests(hosts = null, Array(rack), numExecutors, priority)
+ }
+ case AllocationType.ANY => constructContainerRequests(
+ hosts = null, racks = null, numExecutors, priority)
+ case _ => throw new IllegalArgumentException(
+ "Unexpected/unsupported request type: " + requestType)
+ }
+ }
+
+ private def constructContainerRequests(
+ hosts: Array[String],
+ racks: Array[String],
+ numExecutors: Int,
+ priority: Int
+ ): ArrayBuffer[ContainerRequest] = {
+
+ val memoryRequest = executorMemory + memoryOverhead
+ val resource = Resource.newInstance(memoryRequest, executorCores)
+
+ val prioritySetting = Records.newRecord(classOf[Priority])
+ prioritySetting.setPriority(priority)
+
+ val requests = new ArrayBuffer[ContainerRequest]()
+ for (i <- 0 until numExecutors) {
+ requests += new ContainerRequest(resource, hosts, racks, prioritySetting)
+ }
+ requests
+ }
+
+ private class StableAllocateResponse(response: AllocateResponse) extends YarnAllocateResponse {
+ override def getAllocatedContainers() = response.getAllocatedContainers()
+ override def getAvailableResources() = response.getAvailableResources()
+ override def getCompletedContainersStatuses() = response.getCompletedContainersStatuses()
+ }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
new file mode 100644
index 0000000000..b32e15738f
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -0,0 +1,538 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.util.{List => JList}
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.regex.Pattern
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
+import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+
+object AllocationType extends Enumeration {
+ type AllocationType = Value
+ val HOST, RACK, ANY = Value
+}
+
+// TODO:
+// Too many params.
+// Needs to be mt-safe
+// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should
+// make it more proactive and decoupled.
+
+// Note that right now, we assume all node asks as uniform in terms of capabilities and priority
+// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for
+// more info on how we are requesting for containers.
+
+/**
+ * Common code for the Yarn container allocator. Contains all the version-agnostic code to
+ * manage container allocation for a running Spark application.
+ */
+private[yarn] abstract class YarnAllocator(
+ conf: Configuration,
+ sparkConf: SparkConf,
+ appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments,
+ preferredNodes: collection.Map[String, collection.Set[SplitInfo]],
+ securityMgr: SecurityManager)
+ extends Logging {
+
+ import YarnAllocator._
+
+ // These three are locked on allocatedHostToContainersMap. Complementary data structures
+ // allocatedHostToContainersMap : containers which are running : host, Set<containerid>
+ // allocatedContainerToHostMap: container to host mapping.
+ private val allocatedHostToContainersMap =
+ new HashMap[String, collection.mutable.Set[ContainerId]]()
+
+ private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
+
+ // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an
+ // allocated node)
+ // As with the two data structures above, tightly coupled with them, and to be locked on
+ // allocatedHostToContainersMap
+ private val allocatedRackCount = new HashMap[String, Int]()
+
+ // Containers to be released in next request to RM
+ private val releasedContainers = new ConcurrentHashMap[ContainerId, Boolean]
+
+ // Number of container requests that have been sent to, but not yet allocated by the
+ // ApplicationMaster.
+ private val numPendingAllocate = new AtomicInteger()
+ private val numExecutorsRunning = new AtomicInteger()
+ // Used to generate a unique id per executor
+ private val executorIdCounter = new AtomicInteger()
+ private val numExecutorsFailed = new AtomicInteger()
+
+ private var maxExecutors = args.numExecutors
+
+ // Keep track of which container is running which executor to remove the executors later
+ private val executorIdToContainer = new HashMap[String, Container]
+
+ protected val executorMemory = args.executorMemory
+ protected val executorCores = args.executorCores
+ protected val (preferredHostToCount, preferredRackToCount) =
+ generateNodeToWeight(conf, preferredNodes)
+
+ // Additional memory overhead - in mb.
+ protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead",
+ math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN))
+
+ private val launcherPool = new ThreadPoolExecutor(
+ // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue
+ sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE,
+ 1, TimeUnit.MINUTES,
+ new LinkedBlockingQueue[Runnable](),
+ new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build())
+ launcherPool.allowCoreThreadTimeOut(true)
+
+ def getNumExecutorsRunning: Int = numExecutorsRunning.intValue
+
+ def getNumExecutorsFailed: Int = numExecutorsFailed.intValue
+
+ /**
+ * Request as many executors from the ResourceManager as needed to reach the desired total.
+ * This takes into account executors already running or pending.
+ */
+ def requestTotalExecutors(requestedTotal: Int): Unit = synchronized {
+ val currentTotal = numPendingAllocate.get + numExecutorsRunning.get
+ if (requestedTotal > currentTotal) {
+ maxExecutors += (requestedTotal - currentTotal)
+ // We need to call `allocateResources` here to avoid the following race condition:
+ // If we request executors twice before `allocateResources` is called, then we will end up
+ // double counting the number requested because `numPendingAllocate` is not updated yet.
+ allocateResources()
+ } else {
+ logInfo(s"Not allocating more executors because there are already $currentTotal " +
+ s"(application requested $requestedTotal total)")
+ }
+ }
+
+ /**
+ * Request that the ResourceManager release the container running the specified executor.
+ */
+ def killExecutor(executorId: String): Unit = synchronized {
+ if (executorIdToContainer.contains(executorId)) {
+ val container = executorIdToContainer.remove(executorId).get
+ internalReleaseContainer(container)
+ numExecutorsRunning.decrementAndGet()
+ maxExecutors -= 1
+ assert(maxExecutors >= 0, "Allocator killed more executors than are allocated!")
+ } else {
+ logWarning(s"Attempted to kill unknown executor $executorId!")
+ }
+ }
+
+ /**
+ * Allocate missing containers based on the number of executors currently pending and running.
+ *
+ * This method prioritizes the allocated container responses from the RM based on node and
+ * rack locality. Additionally, it releases any extra containers allocated for this application
+ * but are not needed. This must be synchronized because variables read in this block are
+ * mutated by other methods.
+ */
+ def allocateResources(): Unit = synchronized {
+ val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get()
+
+ // this is needed by alpha, do it here since we add numPending right after this
+ val executorsPending = numPendingAllocate.get()
+ if (missing > 0) {
+ val totalExecutorMemory = executorMemory + memoryOverhead
+ numPendingAllocate.addAndGet(missing)
+ logInfo(s"Will allocate $missing executor containers, each with $totalExecutorMemory MB " +
+ s"memory including $memoryOverhead MB overhead")
+ } else {
+ logDebug("Empty allocation request ...")
+ }
+
+ val allocateResponse = allocateContainers(missing, executorsPending)
+ val allocatedContainers = allocateResponse.getAllocatedContainers()
+
+ if (allocatedContainers.size > 0) {
+ var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size)
+
+ if (numPendingAllocateNow < 0) {
+ numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow)
+ }
+
+ logDebug("""
+ Allocated containers: %d
+ Current executor count: %d
+ Containers released: %s
+ Cluster resources: %s
+ """.format(
+ allocatedContainers.size,
+ numExecutorsRunning.get(),
+ releasedContainers,
+ allocateResponse.getAvailableResources))
+
+ val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ for (container <- allocatedContainers) {
+ if (isResourceConstraintSatisfied(container)) {
+ // Add the accepted `container` to the host's list of already accepted,
+ // allocated containers
+ val host = container.getNodeId.getHost
+ val containersForHost = hostToContainers.getOrElseUpdate(host,
+ new ArrayBuffer[Container]())
+ containersForHost += container
+ } else {
+ // Release container, since it doesn't satisfy resource constraints.
+ internalReleaseContainer(container)
+ }
+ }
+
+ // Find the appropriate containers to use.
+ // TODO: Cleanup this group-by...
+ val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ for (candidateHost <- hostToContainers.keySet) {
+ val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
+ val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
+
+ val remainingContainersOpt = hostToContainers.get(candidateHost)
+ assert(remainingContainersOpt.isDefined)
+ var remainingContainers = remainingContainersOpt.get
+
+ if (requiredHostCount >= remainingContainers.size) {
+ // Since we have <= required containers, add all remaining containers to
+ // `dataLocalContainers`.
+ dataLocalContainers.put(candidateHost, remainingContainers)
+ // There are no more free containers remaining.
+ remainingContainers = null
+ } else if (requiredHostCount > 0) {
+ // Container list has more containers than we need for data locality.
+ // Split the list into two: one based on the data local container count,
+ // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining
+ // containers.
+ val (dataLocal, remaining) = remainingContainers.splitAt(
+ remainingContainers.size - requiredHostCount)
+ dataLocalContainers.put(candidateHost, dataLocal)
+
+ // Invariant: remainingContainers == remaining
+
+ // YARN has a nasty habit of allocating a ton of containers on a host - discourage this.
+ // Add each container in `remaining` to list of containers to release. If we have an
+ // insufficient number of containers, then the next allocation cycle will reallocate
+ // (but won't treat it as data local).
+ // TODO(harvey): Rephrase this comment some more.
+ for (container <- remaining) internalReleaseContainer(container)
+ remainingContainers = null
+ }
+
+ // For rack local containers
+ if (remainingContainers != null) {
+ val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost)
+ if (rack != null) {
+ val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
+ val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
+ rackLocalContainers.getOrElse(rack, List()).size
+
+ if (requiredRackCount >= remainingContainers.size) {
+ // Add all remaining containers to to `dataLocalContainers`.
+ dataLocalContainers.put(rack, remainingContainers)
+ remainingContainers = null
+ } else if (requiredRackCount > 0) {
+ // Container list has more containers that we need for data locality.
+ // Split the list into two: one based on the data local container count,
+ // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining
+ // containers.
+ val (rackLocal, remaining) = remainingContainers.splitAt(
+ remainingContainers.size - requiredRackCount)
+ val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack,
+ new ArrayBuffer[Container]())
+
+ existingRackLocal ++= rackLocal
+
+ remainingContainers = remaining
+ }
+ }
+ }
+
+ if (remainingContainers != null) {
+ // Not all containers have been consumed - add them to the list of off-rack containers.
+ offRackContainers.put(candidateHost, remainingContainers)
+ }
+ }
+
+ // Now that we have split the containers into various groups, go through them in order:
+ // first host-local, then rack-local, and finally off-rack.
+ // Note that the list we create below tries to ensure that not all containers end up within
+ // a host if there is a sufficiently large number of hosts/containers.
+ val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size)
+ allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers)
+ allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers)
+ allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers)
+
+ // Run each of the allocated containers.
+ for (container <- allocatedContainersToProcess) {
+ val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet()
+ val executorHostname = container.getNodeId.getHost
+ val containerId = container.getId
+
+ val executorMemoryOverhead = (executorMemory + memoryOverhead)
+ assert(container.getResource.getMemory >= executorMemoryOverhead)
+
+ if (numExecutorsRunningNow > maxExecutors) {
+ logInfo("""Ignoring container %s at host %s, since we already have the required number of
+ containers for it.""".format(containerId, executorHostname))
+ internalReleaseContainer(container)
+ numExecutorsRunning.decrementAndGet()
+ } else {
+ val executorId = executorIdCounter.incrementAndGet().toString
+ val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ SparkEnv.driverActorSystemName,
+ sparkConf.get("spark.driver.host"),
+ sparkConf.get("spark.driver.port"),
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ logInfo("Launching container %s for on host %s".format(containerId, executorHostname))
+ executorIdToContainer(executorId) = container
+
+ // To be safe, remove the container from `releasedContainers`.
+ releasedContainers.remove(containerId)
+
+ val rack = YarnSparkHadoopUtil.lookupRack(conf, executorHostname)
+ allocatedHostToContainersMap.synchronized {
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname,
+ new HashSet[ContainerId]())
+
+ containerSet += containerId
+ allocatedContainerToHostMap.put(containerId, executorHostname)
+
+ if (rack != null) {
+ allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+ }
+ }
+ logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format(
+ driverUrl, executorHostname))
+ val executorRunnable = new ExecutorRunnable(
+ container,
+ conf,
+ sparkConf,
+ driverUrl,
+ executorId,
+ executorHostname,
+ executorMemory,
+ executorCores,
+ appAttemptId.getApplicationId.toString,
+ securityMgr)
+ launcherPool.execute(executorRunnable)
+ }
+ }
+ logDebug("""
+ Finished allocating %s containers (from %s originally).
+ Current number of executors running: %d,
+ Released containers: %s
+ """.format(
+ allocatedContainersToProcess,
+ allocatedContainers,
+ numExecutorsRunning.get(),
+ releasedContainers))
+ }
+
+ val completedContainers = allocateResponse.getCompletedContainersStatuses()
+ if (completedContainers.size > 0) {
+ logDebug("Completed %d containers".format(completedContainers.size))
+
+ for (completedContainer <- completedContainers) {
+ val containerId = completedContainer.getContainerId
+
+ if (releasedContainers.containsKey(containerId)) {
+ // YarnAllocationHandler already marked the container for release, so remove it from
+ // `releasedContainers`.
+ releasedContainers.remove(containerId)
+ } else {
+ // Decrement the number of executors running. The next iteration of
+ // the ApplicationMaster's reporting thread will take care of allocating.
+ numExecutorsRunning.decrementAndGet()
+ logInfo("Completed container %s (state: %s, exit status: %s)".format(
+ containerId,
+ completedContainer.getState,
+ completedContainer.getExitStatus))
+ // Hadoop 2.2.X added a ContainerExitStatus we should switch to use
+ // there are some exit status' we shouldn't necessarily count against us, but for
+ // now I think its ok as none of the containers are expected to exit
+ if (completedContainer.getExitStatus == -103) { // vmem limit exceeded
+ logWarning(memLimitExceededLogMessage(
+ completedContainer.getDiagnostics,
+ VMEM_EXCEEDED_PATTERN))
+ } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded
+ logWarning(memLimitExceededLogMessage(
+ completedContainer.getDiagnostics,
+ PMEM_EXCEEDED_PATTERN))
+ } else if (completedContainer.getExitStatus != 0) {
+ logInfo("Container marked as failed: " + containerId +
+ ". Exit status: " + completedContainer.getExitStatus +
+ ". Diagnostics: " + completedContainer.getDiagnostics)
+ numExecutorsFailed.incrementAndGet()
+ }
+ }
+
+ allocatedHostToContainersMap.synchronized {
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val hostOpt = allocatedContainerToHostMap.get(containerId)
+ assert(hostOpt.isDefined)
+ val host = hostOpt.get
+
+ val containerSetOpt = allocatedHostToContainersMap.get(host)
+ assert(containerSetOpt.isDefined)
+ val containerSet = containerSetOpt.get
+
+ containerSet.remove(containerId)
+ if (containerSet.isEmpty) {
+ allocatedHostToContainersMap.remove(host)
+ } else {
+ allocatedHostToContainersMap.update(host, containerSet)
+ }
+
+ allocatedContainerToHostMap.remove(containerId)
+
+ // TODO: Move this part outside the synchronized block?
+ val rack = YarnSparkHadoopUtil.lookupRack(conf, host)
+ if (rack != null) {
+ val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
+ if (rackCount > 0) {
+ allocatedRackCount.put(rack, rackCount)
+ } else {
+ allocatedRackCount.remove(rack)
+ }
+ }
+ }
+ }
+ }
+ logDebug("""
+ Finished processing %d completed containers.
+ Current number of executors running: %d,
+ Released containers: %s
+ """.format(
+ completedContainers.size,
+ numExecutorsRunning.get(),
+ releasedContainers))
+ }
+ }
+
+ protected def allocatedContainersOnHost(host: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
+ }
+ retval
+ }
+
+ protected def allocatedContainersOnRack(rack: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedRackCount.getOrElse(rack, 0)
+ }
+ retval
+ }
+
+ private def isResourceConstraintSatisfied(container: Container): Boolean = {
+ container.getResource.getMemory >= (executorMemory + memoryOverhead)
+ }
+
+ // A simple method to copy the split info map.
+ private def generateNodeToWeight(
+ conf: Configuration,
+ input: collection.Map[String, collection.Set[SplitInfo]]
+ ): (Map[String, Int], Map[String, Int]) = {
+
+ if (input == null) {
+ return (Map[String, Int](), Map[String, Int]())
+ }
+
+ val hostToCount = new HashMap[String, Int]
+ val rackToCount = new HashMap[String, Int]
+
+ for ((host, splits) <- input) {
+ val hostCount = hostToCount.getOrElse(host, 0)
+ hostToCount.put(host, hostCount + splits.size)
+
+ val rack = YarnSparkHadoopUtil.lookupRack(conf, host)
+ if (rack != null) {
+ val rackCount = rackToCount.getOrElse(host, 0)
+ rackToCount.put(host, rackCount + splits.size)
+ }
+ }
+
+ (hostToCount.toMap, rackToCount.toMap)
+ }
+
+ private def internalReleaseContainer(container: Container) = {
+ releasedContainers.put(container.getId(), true)
+ releaseContainer(container)
+ }
+
+ /**
+ * Called to allocate containers in the cluster.
+ *
+ * @param count Number of containers to allocate.
+ * If zero, should still contact RM (as a heartbeat).
+ * @param pending Number of containers pending allocate. Only used on alpha.
+ * @return Response to the allocation request.
+ */
+ protected def allocateContainers(count: Int, pending: Int): YarnAllocateResponse
+
+ /** Called to release a previously allocated container. */
+ protected def releaseContainer(container: Container): Unit
+
+ /**
+ * Defines the interface for an allocate response from the RM. This is needed since the alpha
+ * and stable interfaces differ here in ways that cannot be fixed using other routes.
+ */
+ protected trait YarnAllocateResponse {
+
+ def getAllocatedContainers(): JList[Container]
+
+ def getAvailableResources(): Resource
+
+ def getCompletedContainersStatuses(): JList[ContainerStatus]
+
+ }
+
+}
+
+private object YarnAllocator {
+ val MEM_REGEX = "[0-9.]+ [KMG]B"
+ val PMEM_EXCEEDED_PATTERN =
+ Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used")
+ val VMEM_EXCEEDED_PATTERN =
+ Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used")
+
+ def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = {
+ val matcher = pattern.matcher(diagnostics)
+ val diag = if (matcher.find()) " " + matcher.group() + "." else ""
+ ("Container killed by YARN for exceeding memory limits." + diag
+ + " Consider boosting spark.yarn.executor.memoryOverhead.")
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
new file mode 100644
index 0000000000..2510b9c9ce
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.{Map, Set}
+
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.api.records._
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkContext}
+import org.apache.spark.scheduler.SplitInfo
+
+/**
+ * Interface that defines a Yarn RM client. Abstracts away Yarn version-specific functionality that
+ * is used by Spark's AM.
+ */
+trait YarnRMClient {
+
+ /**
+ * Registers the application master with the RM.
+ *
+ * @param conf The Yarn configuration.
+ * @param sparkConf The Spark configuration.
+ * @param preferredNodeLocations Map with hints about where to allocate containers.
+ * @param uiAddress Address of the SparkUI.
+ * @param uiHistoryAddress Address of the application on the History Server.
+ */
+ def register(
+ conf: YarnConfiguration,
+ sparkConf: SparkConf,
+ preferredNodeLocations: Map[String, Set[SplitInfo]],
+ uiAddress: String,
+ uiHistoryAddress: String,
+ securityMgr: SecurityManager): YarnAllocator
+
+ /**
+ * Unregister the AM. Guaranteed to only be called once.
+ *
+ * @param status The final status of the AM.
+ * @param diagnostics Diagnostics message to include in the final status.
+ */
+ def unregister(status: FinalApplicationStatus, diagnostics: String = ""): Unit
+
+ /** Returns the attempt ID. */
+ def getAttemptId(): ApplicationAttemptId
+
+ /** Returns the configuration for the AmIpFilter to add to the Spark UI. */
+ def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String): Map[String, String]
+
+ /** Returns the maximum number of attempts to register the AM. */
+ def getMaxRegAttempts(conf: YarnConfiguration): Int
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala
new file mode 100644
index 0000000000..8d4b96ed79
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.util.{List => JList}
+
+import scala.collection.{Map, Set}
+import scala.collection.JavaConversions._
+import scala.util._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.ConverterUtils
+import org.apache.hadoop.yarn.webapp.util.WebAppUtils
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.scheduler.SplitInfo
+import org.apache.spark.util.Utils
+
+
+/**
+ * YarnRMClient implementation for the Yarn stable API.
+ */
+private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMClient with Logging {
+
+ private var amClient: AMRMClient[ContainerRequest] = _
+ private var uiHistoryAddress: String = _
+ private var registered: Boolean = false
+
+ override def register(
+ conf: YarnConfiguration,
+ sparkConf: SparkConf,
+ preferredNodeLocations: Map[String, Set[SplitInfo]],
+ uiAddress: String,
+ uiHistoryAddress: String,
+ securityMgr: SecurityManager) = {
+ amClient = AMRMClient.createAMRMClient()
+ amClient.init(conf)
+ amClient.start()
+ this.uiHistoryAddress = uiHistoryAddress
+
+ logInfo("Registering the ApplicationMaster")
+ synchronized {
+ amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
+ registered = true
+ }
+ new YarnAllocationHandler(conf, sparkConf, amClient, getAttemptId(), args,
+ preferredNodeLocations, securityMgr)
+ }
+
+ override def unregister(status: FinalApplicationStatus, diagnostics: String = "") = synchronized {
+ if (registered) {
+ amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
+ }
+ }
+
+ override def getAttemptId() = {
+ val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ appAttemptId
+ }
+
+ override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = {
+ // Figure out which scheme Yarn is using. Note the method seems to have been added after 2.2,
+ // so not all stable releases have it.
+ val prefix = Try(classOf[WebAppUtils].getMethod("getHttpSchemePrefix", classOf[Configuration])
+ .invoke(null, conf).asInstanceOf[String]).getOrElse("http://")
+
+ // If running a new enough Yarn, use the HA-aware API for retrieving the RM addresses.
+ try {
+ val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter",
+ classOf[Configuration])
+ val proxies = method.invoke(null, conf).asInstanceOf[JList[String]]
+ val hosts = proxies.map { proxy => proxy.split(":")(0) }
+ val uriBases = proxies.map { proxy => prefix + proxy + proxyBase }
+ Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(","))
+ } catch {
+ case e: NoSuchMethodException =>
+ val proxy = WebAppUtils.getProxyHostAndPort(conf)
+ val parts = proxy.split(":")
+ val uriBase = prefix + proxy + proxyBase
+ Map("PROXY_HOST" -> parts(0), "PROXY_URI_BASE" -> uriBase)
+ }
+ }
+
+ override def getMaxRegAttempts(conf: YarnConfiguration) =
+ conf.getInt(YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
new file mode 100644
index 0000000000..d7cf904db1
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.lang.{Boolean => JBoolean}
+import java.io.File
+import java.util.{Collections, Set => JSet}
+import java.util.regex.Matcher
+import java.util.regex.Pattern
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.mutable.HashMap
+
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.api.records.ApplicationAccessType
+import org.apache.hadoop.yarn.util.RackResolver
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.Utils
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+class YarnSparkHadoopUtil extends SparkHadoopUtil {
+
+ override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
+ dest.addCredentials(source.getCredentials())
+ }
+
+ // Note that all params which start with SPARK are propagated all the way through, so if in yarn
+ // mode, this MUST be set to true.
+ override def isYarnMode(): Boolean = { true }
+
+ // Return an appropriate (subclass) of Configuration. Creating a config initializes some Hadoop
+ // subsystems. Always create a new config, dont reuse yarnConf.
+ override def newConfiguration(conf: SparkConf): Configuration =
+ new YarnConfiguration(super.newConfiguration(conf))
+
+ // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
+ // cluster
+ override def addCredentials(conf: JobConf) {
+ val jobCreds = conf.getCredentials()
+ jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
+ }
+
+ override def getCurrentUserCredentials(): Credentials = {
+ UserGroupInformation.getCurrentUser().getCredentials()
+ }
+
+ override def addCurrentUserCredentials(creds: Credentials) {
+ UserGroupInformation.getCurrentUser().addCredentials(creds)
+ }
+
+ override def addSecretKeyToUserCredentials(key: String, secret: String) {
+ val creds = new Credentials()
+ creds.addSecretKey(new Text(key), secret.getBytes("utf-8"))
+ addCurrentUserCredentials(creds)
+ }
+
+ override def getSecretKeyFromUserCredentials(key: String): Array[Byte] = {
+ val credentials = getCurrentUserCredentials()
+ if (credentials != null) credentials.getSecretKey(new Text(key)) else null
+ }
+
+}
+
+object YarnSparkHadoopUtil {
+ // Additional memory overhead
+ // 7% was arrived at experimentally. In the interest of minimizing memory waste while covering
+ // the common cases. Memory overhead tends to grow with container size.
+
+ val MEMORY_OVERHEAD_FACTOR = 0.07
+ val MEMORY_OVERHEAD_MIN = 384
+
+ val ANY_HOST = "*"
+
+ val DEFAULT_NUMBER_EXECUTORS = 2
+
+ // All RM requests are issued with same priority : we do not (yet) have any distinction between
+ // request types (like map/reduce in hadoop for example)
+ val RM_REQUEST_PRIORITY = 1
+
+ // Host to rack map - saved from allocation requests. We are expecting this not to change.
+ // Note that it is possible for this to change : and ResourceManager will indicate that to us via
+ // update response to allocate. But we are punting on handling that for now.
+ private val hostToRack = new ConcurrentHashMap[String, String]()
+ private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+
+ /**
+ * Add a path variable to the given environment map.
+ * If the map already contains this key, append the value to the existing value instead.
+ */
+ def addPathToEnvironment(env: HashMap[String, String], key: String, value: String): Unit = {
+ val newValue = if (env.contains(key)) { env(key) + File.pathSeparator + value } else value
+ env.put(key, newValue)
+ }
+
+ /**
+ * Set zero or more environment variables specified by the given input string.
+ * The input string is expected to take the form "KEY1=VAL1,KEY2=VAL2,KEY3=VAL3".
+ */
+ def setEnvFromInputString(env: HashMap[String, String], inputString: String): Unit = {
+ if (inputString != null && inputString.length() > 0) {
+ val childEnvs = inputString.split(",")
+ val p = Pattern.compile(environmentVariableRegex)
+ for (cEnv <- childEnvs) {
+ val parts = cEnv.split("=") // split on '='
+ val m = p.matcher(parts(1))
+ val sb = new StringBuffer
+ while (m.find()) {
+ val variable = m.group(1)
+ var replace = ""
+ if (env.get(variable) != None) {
+ replace = env.get(variable).get
+ } else {
+ // if this key is not configured for the child .. get it from the env
+ replace = System.getenv(variable)
+ if (replace == null) {
+ // the env key is note present anywhere .. simply set it
+ replace = ""
+ }
+ }
+ m.appendReplacement(sb, Matcher.quoteReplacement(replace))
+ }
+ m.appendTail(sb)
+ // This treats the environment variable as path variable delimited by `File.pathSeparator`
+ // This is kept for backward compatibility and consistency with Hadoop's behavior
+ addPathToEnvironment(env, parts(0), sb.toString)
+ }
+ }
+ }
+
+ private val environmentVariableRegex: String = {
+ if (Utils.isWindows) {
+ "%([A-Za-z_][A-Za-z0-9_]*?)%"
+ } else {
+ "\\$([A-Za-z_][A-Za-z0-9_]*)"
+ }
+ }
+
+ /**
+ * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands
+ * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The
+ * argument is enclosed in single quotes and some key characters are escaped.
+ *
+ * @param arg A single argument.
+ * @return Argument quoted for execution via Yarn's generated shell script.
+ */
+ def escapeForShell(arg: String): String = {
+ if (arg != null) {
+ val escaped = new StringBuilder("'")
+ for (i <- 0 to arg.length() - 1) {
+ arg.charAt(i) match {
+ case '$' => escaped.append("\\$")
+ case '"' => escaped.append("\\\"")
+ case '\'' => escaped.append("'\\''")
+ case c => escaped.append(c)
+ }
+ }
+ escaped.append("'").toString()
+ } else {
+ arg
+ }
+ }
+
+ def lookupRack(conf: Configuration, host: String): String = {
+ if (!hostToRack.contains(host)) {
+ populateRackInfo(conf, host)
+ }
+ hostToRack.get(host)
+ }
+
+ def populateRackInfo(conf: Configuration, hostname: String) {
+ Utils.checkHost(hostname)
+
+ if (!hostToRack.containsKey(hostname)) {
+ // If there are repeated failures to resolve, all to an ignore list.
+ val rackInfo = RackResolver.resolve(conf, hostname)
+ if (rackInfo != null && rackInfo.getNetworkLocation != null) {
+ val rack = rackInfo.getNetworkLocation
+ hostToRack.put(hostname, rack)
+ if (! rackToHostSet.containsKey(rack)) {
+ rackToHostSet.putIfAbsent(rack,
+ Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
+ }
+ rackToHostSet.get(rack).add(hostname)
+
+ // TODO(harvey): Figure out what this comment means...
+ // Since RackResolver caches, we are disabling this for now ...
+ } /* else {
+ // right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
+ hostToRack.put(hostname, null)
+ } */
+ }
+ }
+
+ def getApplicationAclsForYarn(securityMgr: SecurityManager)
+ : Map[ApplicationAccessType, String] = {
+ Map[ApplicationAccessType, String] (
+ ApplicationAccessType.VIEW_APP -> securityMgr.getViewAcls,
+ ApplicationAccessType.MODIFY_APP -> securityMgr.getModifyAcls
+ )
+ }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
new file mode 100644
index 0000000000..254774a6b8
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.util.Utils
+
+/**
+ * This scheduler launches executors through Yarn - by calling into Client to launch the Spark AM.
+ */
+private[spark] class YarnClientClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host))
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
new file mode 100644
index 0000000000..2923e6729c
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
+
+import org.apache.spark.{SparkException, Logging, SparkContext}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+import org.apache.spark.scheduler.TaskSchedulerImpl
+
+private[spark] class YarnClientSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends YarnSchedulerBackend(scheduler, sc)
+ with Logging {
+
+ private var client: Client = null
+ private var appId: ApplicationId = null
+ @volatile private var stopping: Boolean = false
+
+ /**
+ * Create a Yarn client to submit an application to the ResourceManager.
+ * This waits until the application is running.
+ */
+ override def start() {
+ super.start()
+ val driverHost = conf.get("spark.driver.host")
+ val driverPort = conf.get("spark.driver.port")
+ val hostport = driverHost + ":" + driverPort
+ sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIAddress) }
+
+ val argsArrayBuf = new ArrayBuffer[String]()
+ argsArrayBuf += ("--arg", hostport)
+ argsArrayBuf ++= getExtraClientArguments
+
+ logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" "))
+ val args = new ClientArguments(argsArrayBuf.toArray, conf)
+ totalExpectedExecutors = args.numExecutors
+ client = new Client(args, conf)
+ appId = client.submitApplication()
+ waitForApplication()
+ asyncMonitorApplication()
+ }
+
+ /**
+ * Return any extra command line arguments to be passed to Client provided in the form of
+ * environment variables or Spark properties.
+ */
+ private def getExtraClientArguments: Seq[String] = {
+ val extraArgs = new ArrayBuffer[String]
+ val optionTuples = // List of (target Client argument, environment variable, Spark property)
+ List(
+ ("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"),
+ ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"),
+ ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"),
+ ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"),
+ ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"),
+ ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"),
+ ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"),
+ ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"),
+ ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"),
+ ("--name", "SPARK_YARN_APP_NAME", "spark.app.name")
+ )
+ optionTuples.foreach { case (optionName, envVar, sparkProp) =>
+ if (System.getenv(envVar) != null) {
+ extraArgs += (optionName, System.getenv(envVar))
+ } else if (sc.getConf.contains(sparkProp)) {
+ extraArgs += (optionName, sc.getConf.get(sparkProp))
+ }
+ }
+ extraArgs
+ }
+
+ /**
+ * Report the state of the application until it is running.
+ * If the application has finished, failed or been killed in the process, throw an exception.
+ * This assumes both `client` and `appId` have already been set.
+ */
+ private def waitForApplication(): Unit = {
+ assert(client != null && appId != null, "Application has not been submitted yet!")
+ val (state, _) = client.monitorApplication(appId, returnOnRunning = true) // blocking
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ throw new SparkException("Yarn application has already ended! " +
+ "It might have been killed or unable to launch application master.")
+ }
+ if (state == YarnApplicationState.RUNNING) {
+ logInfo(s"Application $appId has started running.")
+ }
+ }
+
+ /**
+ * Monitor the application state in a separate thread.
+ * If the application has exited for any reason, stop the SparkContext.
+ * This assumes both `client` and `appId` have already been set.
+ */
+ private def asyncMonitorApplication(): Unit = {
+ assert(client != null && appId != null, "Application has not been submitted yet!")
+ val t = new Thread {
+ override def run() {
+ while (!stopping) {
+ val report = client.getApplicationReport(appId)
+ val state = report.getYarnApplicationState()
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.KILLED ||
+ state == YarnApplicationState.FAILED) {
+ logError(s"Yarn application has already exited with state $state!")
+ sc.stop()
+ stopping = true
+ }
+ Thread.sleep(1000L)
+ }
+ Thread.currentThread().interrupt()
+ }
+ }
+ t.setName("Yarn application state monitor")
+ t.setDaemon(true)
+ t.start()
+ }
+
+ /**
+ * Stop the scheduler. This assumes `start()` has already been called.
+ */
+ override def stop() {
+ assert(client != null, "Attempted to stop this scheduler before starting it!")
+ stopping = true
+ super.stop()
+ client.stop()
+ logInfo("Stopped")
+ }
+
+ override def applicationId(): String = {
+ Option(appId).map(_.toString).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
+ }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
new file mode 100644
index 0000000000..4157ff95c2
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil}
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.util.Utils
+
+/**
+ * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of
+ * ApplicationMaster, etc is done
+ */
+private[spark] class YarnClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {
+
+ logInfo("Created YarnClusterScheduler")
+
+ // Nothing else for now ... initialize application master : which needs a SparkContext to
+ // determine how to allocate.
+ // Note that only the first creation of a SparkContext influences (and ideally, there must be
+ // only one SparkContext, right ?). Subsequent creations are ignored since executors are already
+ // allocated by then.
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host))
+ }
+
+ override def postStartHook() {
+ ApplicationMaster.sparkContextInitialized(sc)
+ super.postStartHook()
+ logInfo("YarnClusterScheduler.postStartHook done")
+ }
+
+ override def stop() {
+ super.stop()
+ ApplicationMaster.sparkContextStopped(sc)
+ }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
new file mode 100644
index 0000000000..b1de81e6a8
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.SparkContext
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.util.IntParam
+
+private[spark] class YarnClusterSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends YarnSchedulerBackend(scheduler, sc) {
+
+ override def start() {
+ super.start()
+ totalExpectedExecutors = DEFAULT_NUMBER_EXECUTORS
+ if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) {
+ totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES"))
+ .getOrElse(totalExpectedExecutors)
+ }
+ // System property can override environment variable.
+ totalExpectedExecutors = sc.getConf.getInt("spark.executor.instances", totalExpectedExecutors)
+ }
+
+ override def applicationId(): String =
+ // In YARN Cluster mode, spark.yarn.app.id is expect to be set
+ // before user application is launched.
+ // So, if spark.yarn.app.id is not set, it is something wrong.
+ sc.getConf.getOption("spark.yarn.app.id").getOrElse {
+ logError("Application ID is not set.")
+ super.applicationId
+ }
+
+}
diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..9dd05f17f0
--- /dev/null
+++ b/yarn/src/test/resources/log4j.properties
@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file core/target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=false
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
+org.eclipse.jetty.LEVEL=WARN
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
new file mode 100644
index 0000000000..17b79ae1d8
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.MRJobConfig
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+
+
+import org.scalatest.FunSuite
+import org.scalatest.Matchers
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ HashMap => MutableHashMap }
+import scala.reflect.ClassTag
+import scala.util.Try
+
+import org.apache.spark.{SparkException, SparkConf}
+import org.apache.spark.util.Utils
+
+class ClientBaseSuite extends FunSuite with Matchers {
+
+ test("default Yarn application classpath") {
+ ClientBase.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP))
+ }
+
+ test("default MR application classpath") {
+ ClientBase.getDefaultMRApplicationClasspath should be(Some(Fixtures.knownDefMRAppCP))
+ }
+
+ test("resultant classpath for an application that defines a classpath for YARN") {
+ withAppConf(Fixtures.mapYARNAppConf) { conf =>
+ val env = newEnv
+ ClientBase.populateHadoopClasspath(conf, env)
+ classpath(env) should be(
+ flatten(Fixtures.knownYARNAppCP, ClientBase.getDefaultMRApplicationClasspath))
+ }
+ }
+
+ test("resultant classpath for an application that defines a classpath for MR") {
+ withAppConf(Fixtures.mapMRAppConf) { conf =>
+ val env = newEnv
+ ClientBase.populateHadoopClasspath(conf, env)
+ classpath(env) should be(
+ flatten(ClientBase.getDefaultYarnApplicationClasspath, Fixtures.knownMRAppCP))
+ }
+ }
+
+ test("resultant classpath for an application that defines both classpaths, YARN and MR") {
+ withAppConf(Fixtures.mapAppConf) { conf =>
+ val env = newEnv
+ ClientBase.populateHadoopClasspath(conf, env)
+ classpath(env) should be(flatten(Fixtures.knownYARNAppCP, Fixtures.knownMRAppCP))
+ }
+ }
+
+ private val SPARK = "local:/sparkJar"
+ private val USER = "local:/userJar"
+ private val ADDED = "local:/addJar1,local:/addJar2,/addJar3"
+
+ test("Local jar URIs") {
+ val conf = new Configuration()
+ val sparkConf = new SparkConf().set(ClientBase.CONF_SPARK_JAR, SPARK)
+ val env = new MutableHashMap[String, String]()
+ val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf)
+
+ ClientBase.populateClasspath(args, conf, sparkConf, env)
+
+ val cp = env("CLASSPATH").split(File.pathSeparator)
+ s"$SPARK,$USER,$ADDED".split(",").foreach({ entry =>
+ val uri = new URI(entry)
+ if (ClientBase.LOCAL_SCHEME.equals(uri.getScheme())) {
+ cp should contain (uri.getPath())
+ } else {
+ cp should not contain (uri.getPath())
+ }
+ })
+ cp should contain (Environment.PWD.$())
+ cp should contain (s"${Environment.PWD.$()}${File.separator}*")
+ cp should not contain (ClientBase.SPARK_JAR)
+ cp should not contain (ClientBase.APP_JAR)
+ }
+
+ test("Jar path propagation through SparkConf") {
+ val conf = new Configuration()
+ val sparkConf = new SparkConf().set(ClientBase.CONF_SPARK_JAR, SPARK)
+ val yarnConf = new YarnConfiguration()
+ val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf)
+
+ val client = spy(new DummyClient(args, conf, sparkConf, yarnConf))
+ doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]),
+ any(classOf[Path]), anyShort(), anyBoolean())
+
+ val tempDir = Utils.createTempDir()
+ try {
+ client.prepareLocalResources(tempDir.getAbsolutePath())
+ sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER))
+
+ // The non-local path should be propagated by name only, since it will end up in the app's
+ // staging dir.
+ val expected = ADDED.split(",")
+ .map(p => {
+ val uri = new URI(p)
+ if (ClientBase.LOCAL_SCHEME == uri.getScheme()) {
+ p
+ } else {
+ Option(uri.getFragment()).getOrElse(new File(p).getName())
+ }
+ })
+ .mkString(",")
+
+ sparkConf.getOption(ClientBase.CONF_SPARK_YARN_SECONDARY_JARS) should be (Some(expected))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ test("check access nns empty") {
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.yarn.access.namenodes", "")
+ val nns = ClientBase.getNameNodesToAccess(sparkConf)
+ nns should be(Set())
+ }
+
+ test("check access nns unset") {
+ val sparkConf = new SparkConf()
+ val nns = ClientBase.getNameNodesToAccess(sparkConf)
+ nns should be(Set())
+ }
+
+ test("check access nns") {
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032")
+ val nns = ClientBase.getNameNodesToAccess(sparkConf)
+ nns should be(Set(new Path("hdfs://nn1:8032")))
+ }
+
+ test("check access nns space") {
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ")
+ val nns = ClientBase.getNameNodesToAccess(sparkConf)
+ nns should be(Set(new Path("hdfs://nn1:8032")))
+ }
+
+ test("check access two nns") {
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032")
+ val nns = ClientBase.getNameNodesToAccess(sparkConf)
+ nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032")))
+ }
+
+ test("check token renewer") {
+ val hadoopConf = new Configuration()
+ hadoopConf.set("yarn.resourcemanager.address", "myrm:8033")
+ hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM")
+ val renewer = ClientBase.getTokenRenewer(hadoopConf)
+ renewer should be ("yarn/myrm:8032@SPARKTEST.COM")
+ }
+
+ test("check token renewer default") {
+ val hadoopConf = new Configuration()
+ val caught =
+ intercept[SparkException] {
+ ClientBase.getTokenRenewer(hadoopConf)
+ }
+ assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer")
+ }
+
+ object Fixtures {
+
+ val knownDefYarnAppCP: Seq[String] =
+ getFieldValue[Array[String], Seq[String]](classOf[YarnConfiguration],
+ "DEFAULT_YARN_APPLICATION_CLASSPATH",
+ Seq[String]())(a => a.toSeq)
+
+
+ val knownDefMRAppCP: Seq[String] =
+ getFieldValue2[String, Array[String], Seq[String]](
+ classOf[MRJobConfig],
+ "DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH",
+ Seq[String]())(a => a.split(","))(a => a.toSeq)
+
+ val knownYARNAppCP = Some(Seq("/known/yarn/path"))
+
+ val knownMRAppCP = Some(Seq("/known/mr/path"))
+
+ val mapMRAppConf =
+ Map("mapreduce.application.classpath" -> knownMRAppCP.map(_.mkString(":")).get)
+
+ val mapYARNAppConf =
+ Map(YarnConfiguration.YARN_APPLICATION_CLASSPATH -> knownYARNAppCP.map(_.mkString(":")).get)
+
+ val mapAppConf = mapYARNAppConf ++ mapMRAppConf
+ }
+
+ def withAppConf(m: Map[String, String] = Map())(testCode: (Configuration) => Any) {
+ val conf = new Configuration
+ m.foreach { case (k, v) => conf.set(k, v, "ClientBaseSpec") }
+ testCode(conf)
+ }
+
+ def newEnv = MutableHashMap[String, String]()
+
+ def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;")
+
+ def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray
+
+ def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B =
+ Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults)
+
+ def getFieldValue2[A: ClassTag, A1: ClassTag, B](
+ clazz: Class[_],
+ field: String,
+ defaults: => B)(mapTo: A => B)(mapTo1: A1 => B) : B = {
+ Try(clazz.getField(field)).map(_.get(null)).map {
+ case v: A => mapTo(v)
+ case v1: A1 => mapTo1(v1)
+ case _ => defaults
+ }.toOption.getOrElse(defaults)
+ }
+
+ private class DummyClient(
+ val args: ClientArguments,
+ val hadoopConf: Configuration,
+ val sparkConf: SparkConf,
+ val yarnConf: YarnConfiguration) extends ClientBase {
+ override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = ???
+ override def submitApplication(): ApplicationId = ???
+ override def getApplicationReport(appId: ApplicationId): ApplicationReport = ???
+ override def getClientToken(report: ApplicationReport): String = ???
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
new file mode 100644
index 0000000000..80b57d1355
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito.when
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+
+class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
+
+ class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
+ override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ LocalResourceVisibility.PRIVATE
+ }
+ }
+
+ test("test getFileStatus empty") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath() === null)
+ }
+
+ test("test getFileStatus cached") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath().toString() === "/tmp/testing")
+ }
+
+ test("test addResource") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 0)
+ assert(resource.getSize() === 0)
+ assert(resource.getType() === LocalResourceType.FILE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env("SPARK_YARN_CACHE_FILES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_FILE_SIZES") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+
+ // add another one and verify both there and order correct
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing2"))
+ val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
+ when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
+ distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+ statCache, false)
+ val resource2 = localResources("link2")
+ assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2)
+ assert(resource2.getTimestamp() === 10)
+ assert(resource2.getSize() === 20)
+ assert(resource2.getType() === LocalResourceType.FILE)
+
+ val env2 = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env2)
+ val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val files = env2("SPARK_YARN_CACHE_FILES").split(',')
+ val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',')
+ assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(timestamps(0) === "0")
+ assert(sizes(0) === "0")
+ assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name())
+
+ assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2")
+ assert(timestamps(1) === "10")
+ assert(sizes(1) === "20")
+ assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name())
+ }
+
+ test("test addResource link null") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ intercept[Exception] {
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+ statCache, false)
+ }
+ assert(localResources.get("link") === None)
+ assert(localResources.size === 0)
+ }
+
+ test("test addResource appmaster only") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, true)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+ }
+
+ test("test addResource archive") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env("SPARK_YARN_CACHE_ARCHIVES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === "10")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === "20")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+ }
+
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
new file mode 100644
index 0000000000..8d184a09d6
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.spark.deploy.yarn.YarnAllocator._
+import org.scalatest.FunSuite
+
+class YarnAllocatorSuite extends FunSuite {
+ test("memory exceeded diagnostic regexes") {
+ val diagnostics =
+ "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +
+ "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " +
+ "5.8 GB of 4.2 GB virtual memory used. Killing container."
+ val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN)
+ val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN)
+ assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used."))
+ assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used."))
+ }
+}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
new file mode 100644
index 0000000000..d79b85e867
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConversions._
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.server.MiniYARNCluster
+
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.Utils
+
+class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging {
+
+ // log4j configuration for the Yarn containers, so that their output is collected
+ // by Yarn instead of trying to overwrite unit-tests.log.
+ private val LOG4J_CONF = """
+ |log4j.rootCategory=DEBUG, console
+ |log4j.appender.console=org.apache.log4j.ConsoleAppender
+ |log4j.appender.console.target=System.err
+ |log4j.appender.console.layout=org.apache.log4j.PatternLayout
+ |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+ """.stripMargin
+
+ private var yarnCluster: MiniYARNCluster = _
+ private var tempDir: File = _
+ private var fakeSparkJar: File = _
+ private var oldConf: Map[String, String] = _
+
+ override def beforeAll() {
+ tempDir = Utils.createTempDir()
+
+ val logConfDir = new File(tempDir, "log4j")
+ logConfDir.mkdir()
+
+ val logConfFile = new File(logConfDir, "log4j.properties")
+ Files.write(LOG4J_CONF, logConfFile, Charsets.UTF_8)
+
+ val childClasspath = logConfDir.getAbsolutePath() + File.pathSeparator +
+ sys.props("java.class.path")
+
+ oldConf = sys.props.filter { case (k, v) => k.startsWith("spark.") }.toMap
+
+ yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
+ yarnCluster.init(new YarnConfiguration())
+ yarnCluster.start()
+
+ // There's a race in MiniYARNCluster in which start() may return before the RM has updated
+ // its address in the configuration. You can see this in the logs by noticing that when
+ // MiniYARNCluster prints the address, it still has port "0" assigned, although later the
+ // test works sometimes:
+ //
+ // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0
+ //
+ // That log message prints the contents of the RM_ADDRESS config variable. If you check it
+ // later on, it looks something like this:
+ //
+ // INFO YarnClusterSuite: RM address in configuration is blah:42631
+ //
+ // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't
+ // done so in a timely manner (defined to be 10 seconds).
+ val config = yarnCluster.getConfig()
+ val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10)
+ while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") {
+ if (System.currentTimeMillis() > deadline) {
+ throw new IllegalStateException("Timed out waiting for RM to come up.")
+ }
+ logDebug("RM address still not set in configuration, waiting...")
+ TimeUnit.MILLISECONDS.sleep(100)
+ }
+
+ logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
+ config.foreach { e =>
+ sys.props += ("spark.hadoop." + e.getKey() -> e.getValue())
+ }
+
+ fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
+ sys.props += ("spark.yarn.jar" -> ("local:" + fakeSparkJar.getAbsolutePath()))
+ sys.props += ("spark.executor.instances" -> "1")
+ sys.props += ("spark.driver.extraClassPath" -> childClasspath)
+ sys.props += ("spark.executor.extraClassPath" -> childClasspath)
+
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ yarnCluster.stop()
+ sys.props.retain { case (k, v) => !k.startsWith("spark.") }
+ sys.props ++= oldConf
+ super.afterAll()
+ }
+
+ test("run Spark in yarn-client mode") {
+ var result = File.createTempFile("result", null, tempDir)
+ YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath()))
+ checkResult(result)
+ }
+
+ test("run Spark in yarn-cluster mode") {
+ val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
+ var result = File.createTempFile("result", null, tempDir)
+
+ val args = Array("--class", main,
+ "--jar", "file:" + fakeSparkJar.getAbsolutePath(),
+ "--arg", "yarn-cluster",
+ "--arg", result.getAbsolutePath(),
+ "--num-executors", "1")
+ Client.main(args)
+ checkResult(result)
+ }
+
+ test("run Spark in yarn-cluster mode unsuccessfully") {
+ val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
+
+ // Use only one argument so the driver will fail
+ val args = Array("--class", main,
+ "--jar", "file:" + fakeSparkJar.getAbsolutePath(),
+ "--arg", "yarn-cluster",
+ "--num-executors", "1")
+ val exception = intercept[SparkException] {
+ Client.main(args)
+ }
+ assert(Utils.exceptionString(exception).contains("Application finished with failed status"))
+ }
+
+ /**
+ * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
+ * any sort of error when the job process finishes successfully, but the job itself fails. So
+ * the tests enforce that something is written to a file after everything is ok to indicate
+ * that the job succeeded.
+ */
+ private def checkResult(result: File) = {
+ var resultString = Files.toString(result, Charsets.UTF_8)
+ resultString should be ("success")
+ }
+
+}
+
+private object YarnClusterDriver extends Logging with Matchers {
+
+ def main(args: Array[String]) = {
+ if (args.length != 2) {
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: YarnClusterDriver [master] [result file]
+ """.stripMargin)
+ System.exit(1)
+ }
+
+ val sc = new SparkContext(new SparkConf().setMaster(args(0))
+ .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns"))
+ val status = new File(args(1))
+ var result = "failure"
+ try {
+ val data = sc.parallelize(1 to 4, 4).collect().toSet
+ data should be (Set(1, 2, 3, 4))
+ result = "success"
+ } finally {
+ sc.stop()
+ Files.write(result, status, Charsets.UTF_8)
+ }
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
new file mode 100644
index 0000000000..2cc5abb3a8
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, IOException}
+
+import com.google.common.io.{ByteStreams, Files}
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.hadoop.yarn.api.records.ApplicationAccessType
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+
+
+class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging {
+
+ val hasBash =
+ try {
+ val exitCode = Runtime.getRuntime().exec(Array("bash", "--version")).waitFor()
+ exitCode == 0
+ } catch {
+ case e: IOException =>
+ false
+ }
+
+ if (!hasBash) {
+ logWarning("Cannot execute bash, skipping bash tests.")
+ }
+
+ def bashTest(name: String)(fn: => Unit) =
+ if (hasBash) test(name)(fn) else ignore(name)(fn)
+
+ bashTest("shell script escaping") {
+ val scriptFile = File.createTempFile("script.", ".sh")
+ val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6")
+ try {
+ val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ")
+ Files.write(("bash -c \"echo " + argLine + "\"").getBytes(), scriptFile)
+ scriptFile.setExecutable(true)
+
+ val proc = Runtime.getRuntime().exec(Array(scriptFile.getAbsolutePath()))
+ val out = new String(ByteStreams.toByteArray(proc.getInputStream())).trim()
+ val err = new String(ByteStreams.toByteArray(proc.getErrorStream()))
+ val exitCode = proc.waitFor()
+ exitCode should be (0)
+ out should be (args.mkString(" "))
+ } finally {
+ scriptFile.delete()
+ }
+ }
+
+ test("Yarn configuration override") {
+ val key = "yarn.nodemanager.hostname"
+ val default = new YarnConfiguration()
+
+ val sparkConf = new SparkConf()
+ .set("spark.hadoop." + key, "someHostName")
+ val yarnConf = new YarnSparkHadoopUtil().newConfiguration(sparkConf)
+
+ yarnConf.getClass() should be (classOf[YarnConfiguration])
+ yarnConf.get(key) should not be default.get(key)
+ }
+
+
+ test("test getApplicationAclsForYarn acls on") {
+
+ // spark acls on, just pick up default user
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.acls.enable", "true")
+
+ val securityMgr = new SecurityManager(sparkConf)
+ val acls = YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)
+
+ val viewAcls = acls.get(ApplicationAccessType.VIEW_APP)
+ val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP)
+
+ viewAcls match {
+ case Some(vacls) => {
+ val aclSet = vacls.split(',').map(_.trim).toSet
+ assert(aclSet.contains(System.getProperty("user.name", "invalid")))
+ }
+ case None => {
+ fail()
+ }
+ }
+ modifyAcls match {
+ case Some(macls) => {
+ val aclSet = macls.split(',').map(_.trim).toSet
+ assert(aclSet.contains(System.getProperty("user.name", "invalid")))
+ }
+ case None => {
+ fail()
+ }
+ }
+ }
+
+ test("test getApplicationAclsForYarn acls on and specify users") {
+
+ // default spark acls are on and specify acls
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.acls.enable", "true")
+ sparkConf.set("spark.ui.view.acls", "user1,user2")
+ sparkConf.set("spark.modify.acls", "user3,user4")
+
+ val securityMgr = new SecurityManager(sparkConf)
+ val acls = YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)
+
+ val viewAcls = acls.get(ApplicationAccessType.VIEW_APP)
+ val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP)
+
+ viewAcls match {
+ case Some(vacls) => {
+ val aclSet = vacls.split(',').map(_.trim).toSet
+ assert(aclSet.contains("user1"))
+ assert(aclSet.contains("user2"))
+ assert(aclSet.contains(System.getProperty("user.name", "invalid")))
+ }
+ case None => {
+ fail()
+ }
+ }
+ modifyAcls match {
+ case Some(macls) => {
+ val aclSet = macls.split(',').map(_.trim).toSet
+ assert(aclSet.contains("user3"))
+ assert(aclSet.contains("user4"))
+ assert(aclSet.contains(System.getProperty("user.name", "invalid")))
+ }
+ case None => {
+ fail()
+ }
+ }
+
+ }
+}