aboutsummaryrefslogtreecommitdiff
path: root/yarn/src
diff options
context:
space:
mode:
authorJey Kottalam <jey@cs.berkeley.edu>2013-07-17 14:53:37 -0700
committerJey Kottalam <jey@cs.berkeley.edu>2013-08-15 16:50:36 -0700
commitb877e20a339872f9a29a35272e6c1f280ac901d5 (patch)
treee55d55f65bddfc91769a28e3c3398c9aed1f6fb0 /yarn/src
parent28369ff7733d0994b8d8580ae4eacd82a7080256 (diff)
downloadspark-b877e20a339872f9a29a35272e6c1f280ac901d5.tar.gz
spark-b877e20a339872f9a29a35272e6c1f280ac901d5.tar.bz2
spark-b877e20a339872f9a29a35272e6c1f280ac901d5.zip
move yarn to its own directory
Diffstat (limited to 'yarn/src')
-rw-r--r--yarn/src/main/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala30
-rw-r--r--yarn/src/main/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala30
-rw-r--r--yarn/src/main/scala/spark/deploy/SparkHadoopUtil.scala76
-rw-r--r--yarn/src/main/scala/spark/deploy/yarn/ApplicationMaster.scala351
-rw-r--r--yarn/src/main/scala/spark/deploy/yarn/ApplicationMasterArguments.scala94
-rw-r--r--yarn/src/main/scala/spark/deploy/yarn/Client.scala327
-rw-r--r--yarn/src/main/scala/spark/deploy/yarn/ClientArguments.scala116
-rw-r--r--yarn/src/main/scala/spark/deploy/yarn/WorkerRunnable.scala217
-rw-r--r--yarn/src/main/scala/spark/deploy/yarn/YarnAllocationHandler.scala564
-rw-r--r--yarn/src/main/scala/spark/scheduler/cluster/YarnClusterScheduler.scala59
10 files changed, 1864 insertions, 0 deletions
diff --git a/yarn/src/main/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/yarn/src/main/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
new file mode 100644
index 0000000000..0f972b7a0b
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -0,0 +1,30 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred
+
+import org.apache.hadoop.mapreduce.TaskType
+
+trait HadoopMapRedUtil {
+ def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+ def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+ new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/yarn/src/main/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/yarn/src/main/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
new file mode 100644
index 0000000000..1a7cdf4788
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapreduce
+
+import org.apache.hadoop.conf.Configuration
+import task.{TaskAttemptContextImpl, JobContextImpl}
+
+trait HadoopMapReduceUtil {
+ def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+ def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+ new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/yarn/src/main/scala/spark/deploy/SparkHadoopUtil.scala b/yarn/src/main/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..6122fdced0
--- /dev/null
+++ b/yarn/src/main/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.deploy
+
+import collection.mutable.HashMap
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import java.security.PrivilegedExceptionAction
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ val yarnConf = newConfiguration()
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to env if -D is not present ...
+ val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))
+
+ // If nothing found, default to user we are running as
+ if (retval == null) System.getProperty("user.name") else retval
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+ runAsUser(func, args, getUserNameFromEnvironment())
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product, user: String) {
+ func(args)
+ }
+
+ // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
+ def isYarnMode(): Boolean = {
+ val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
+ java.lang.Boolean.valueOf(yarnMode)
+ }
+
+ // Set an env variable indicating we are running in YARN mode.
+ // Note that anything with SPARK prefix gets propagated to all (remote) processes
+ def setYarnMode() {
+ System.setProperty("SPARK_YARN_MODE", "true")
+ }
+
+ def setYarnMode(env: HashMap[String, String]) {
+ env("SPARK_YARN_MODE") = "true"
+ }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ // Always create a new config, dont reuse yarnConf.
+ def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
+
+ // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster
+ def addCredentials(conf: JobConf) {
+ val jobCreds = conf.getCredentials();
+ jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
+ }
+}
diff --git a/yarn/src/main/scala/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000..1b06169739
--- /dev/null
+++ b/yarn/src/main/scala/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.deploy.yarn
+
+import java.net.Socket
+import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import scala.collection.JavaConversions._
+import spark.{SparkContext, Logging, Utils}
+import org.apache.hadoop.security.UserGroupInformation
+import java.security.PrivilegedExceptionAction
+
+class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private var rpc: YarnRPC = YarnRPC.create(conf)
+ private var resourceManager: AMRMProtocol = null
+ private var appAttemptId: ApplicationAttemptId = null
+ private var userThread: Thread = null
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ private var yarnAllocator: YarnAllocationHandler = null
+ private var isFinished:Boolean = false
+
+ def run() {
+
+ appAttemptId = getApplicationAttemptId()
+ resourceManager = registerWithResourceManager()
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+ // Compute number of threads for akka
+ val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+ if (minimumMemory > 0) {
+ val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+ if (numCore > 0) {
+ // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+ // TODO: Uncomment when hadoop is on a version which has this fixed.
+ // args.workerCores = numCore
+ }
+ }
+
+ // Workaround until hadoop moves to something which has
+ // https://issues.apache.org/jira/browse/HADOOP-8406
+ // ignore result
+ // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
+ // Hence args.workerCores = numCore disabled above. Any better option ?
+ // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
+
+ ApplicationMaster.register(this)
+ // Start the user's JAR
+ userThread = startUserClass()
+
+ // This a bit hacky, but we need to wait until the spark.driver.port property has
+ // been set by the Thread executing the user class.
+ waitForSparkMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Wait for the user class to Finish
+ userThread.join()
+
+ System.exit(0)
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ return appAttemptId
+ }
+
+ private def registerWithResourceManager(): AMRMProtocol = {
+ val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
+ YarnConfiguration.RM_SCHEDULER_ADDRESS,
+ YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
+ logInfo("Connecting to ResourceManager at " + rmAddress)
+ return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
+ .asInstanceOf[RegisterApplicationMasterRequest]
+ appMasterRequest.setApplicationAttemptId(appAttemptId)
+ // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
+ // Users can then monitor stderr/stdout on that node if required.
+ appMasterRequest.setHost(Utils.localHostName())
+ appMasterRequest.setRpcPort(0)
+ // What do we provide here ? Might make sense to expose something sensible later ?
+ appMasterRequest.setTrackingUrl("")
+ return resourceManager.registerApplicationMaster(appMasterRequest)
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for spark driver to be reachable.")
+ var driverUp = false
+ while(!driverUp) {
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ try {
+ val socket = new Socket(driverHost, driverPort.toInt)
+ socket.close()
+ logInfo("Master now available: " + driverHost + ":" + driverPort)
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
+ Thread.sleep(100)
+ }
+ }
+ }
+
+ private def startUserClass(): Thread = {
+ logInfo("Starting the user JAR in a separate Thread")
+ val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
+ .getMethod("main", classOf[Array[String]])
+ val t = new Thread {
+ override def run() {
+ var successed = false
+ try {
+ // Copy
+ var mainArgs: Array[String] = new Array[String](args.userArgs.size())
+ args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size())
+ mainMethod.invoke(null, mainArgs)
+ // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR
+ // userThread will stop here unless it has uncaught exception thrown out
+ // It need shutdown hook to set SUCCEEDED
+ successed = true
+ } finally {
+ if (successed) {
+ ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+ } else {
+ ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED)
+ }
+ }
+ }
+ }
+ t.start()
+ return t
+ }
+
+ private def allocateWorkers() {
+ logInfo("Waiting for spark context initialization")
+
+ try {
+ var sparkContext: SparkContext = null
+ ApplicationMaster.sparkContextRef.synchronized {
+ var count = 0
+ while (ApplicationMaster.sparkContextRef.get() == null) {
+ logInfo("Waiting for spark context initialization ... " + count)
+ count = count + 1
+ ApplicationMaster.sparkContextRef.wait(10000L)
+ }
+ sparkContext = ApplicationMaster.sparkContextRef.get()
+ assert(sparkContext != null)
+ this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData)
+ }
+
+
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+ while(yarnAllocator.getNumWorkersRunning < args.numWorkers &&
+ // If user thread exists, then quit !
+ userThread.isAlive) {
+
+ this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+ ApplicationMaster.incrementAllocatorLoop(1)
+ Thread.sleep(100)
+ }
+ } finally {
+ // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT :
+ // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
+ ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ }
+ logInfo("All workers have launched.")
+
+ // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
+ if (userThread.isAlive) {
+ // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
+
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ // must be <= timeoutInterval/ 2.
+ // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
+ val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ launchReporterThread(interval)
+ }
+ }
+
+ // TODO: We might want to extend this to allocate more containers in case they die !
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (userThread.isAlive) {
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+ yarnAllocator.allocateContainers(missingWorkerCount)
+ }
+ else sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ return t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // simulated with an allocate request with no nodes requested ...
+ yarnAllocator.allocateContainers(0)
+ }
+
+ /*
+ def printContainers(containers: List[Container]) = {
+ for (container <- containers) {
+ logInfo("Launching shell command on a new container."
+ + ", containerId=" + container.getId()
+ + ", containerNode=" + container.getNodeId().getHost()
+ + ":" + container.getNodeId().getPort()
+ + ", containerNodeURI=" + container.getNodeHttpAddress()
+ + ", containerState" + container.getState()
+ + ", containerResourceMemory"
+ + container.getResource().getMemory())
+ }
+ }
+ */
+
+ def finishApplicationMaster(status: FinalApplicationStatus) {
+
+ synchronized {
+ if (isFinished) {
+ return
+ }
+ isFinished = true
+ }
+
+ logInfo("finishApplicationMaster with " + status)
+ val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+ .asInstanceOf[FinishApplicationMasterRequest]
+ finishReq.setAppAttemptId(appAttemptId)
+ finishReq.setFinishApplicationStatus(status)
+ resourceManager.finishApplicationMaster(finishReq)
+
+ }
+
+}
+
+object ApplicationMaster {
+ // number of times to wait for the allocator loop to complete.
+ // each loop iteration waits for 100ms, so maximum of 3 seconds.
+ // This is to ensure that we have reasonable number of containers before we start
+ // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be optimal as more
+ // containers are available. Might need to handle this better.
+ private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+ def incrementAllocatorLoop(by: Int) {
+ val count = yarnAllocatorLoop.getAndAdd(by)
+ if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
+ yarnAllocatorLoop.synchronized {
+ // to wake threads off wait ...
+ yarnAllocatorLoop.notifyAll()
+ }
+ }
+ }
+
+ private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
+
+ def register(master: ApplicationMaster) {
+ applicationMasters.add(master)
+ }
+
+ val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
+ val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
+
+ def sparkContextInitialized(sc: SparkContext): Boolean = {
+ var modified = false
+ sparkContextRef.synchronized {
+ modified = sparkContextRef.compareAndSet(null, sc)
+ sparkContextRef.notifyAll()
+ }
+
+ // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit
+ // Should not really have to do this, but it helps yarn to evict resources earlier.
+ // not to mention, prevent Client declaring failure even though we exit'ed properly.
+ if (modified) {
+ Runtime.getRuntime().addShutdownHook(new Thread with Logging {
+ // This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run'
+ logInfo("Adding shutdown hook for context " + sc)
+ override def run() {
+ logInfo("Invoking sc stop from shutdown hook")
+ sc.stop()
+ // best case ...
+ for (master <- applicationMasters) {
+ master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+ }
+ }
+ } )
+ }
+
+ // Wait for initialization to complete and atleast 'some' nodes can get allocated
+ yarnAllocatorLoop.synchronized {
+ while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
+ yarnAllocatorLoop.wait(1000L)
+ }
+ }
+ modified
+ }
+
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new ApplicationMaster(args).run()
+ }
+}
diff --git a/yarn/src/main/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000..8de44b1f66
--- /dev/null
+++ b/yarn/src/main/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.deploy.yarn
+
+import spark.util.IntParam
+import collection.mutable.ArrayBuffer
+
+class ApplicationMasterArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: IntParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ }
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
+ System.exit(exitCode)
+ }
+}
diff --git a/yarn/src/main/scala/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/spark/deploy/yarn/Client.scala
new file mode 100644
index 0000000000..8bcbfc2735
--- /dev/null
+++ b/yarn/src/main/scala/spark/deploy/yarn/Client.scala
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.deploy.yarn
+
+import java.net.{InetSocketAddress, URI}
+import java.nio.ByteBuffer
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.mapred.Master
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.client.YarnClientImpl
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+import spark.{Logging, Utils}
+import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import spark.deploy.SparkHadoopUtil
+
+class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
+
+ def this(args: ClientArguments) = this(new Configuration(), args)
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials();
+
+ def run() {
+ init(yarnConf)
+ start()
+ logClusterResourceDetails()
+
+ val newApp = super.getNewApplication()
+ val appId = newApp.getApplicationId()
+
+ verifyClusterResources(newApp)
+ val appContext = createApplicationSubmissionContext(appId)
+ val localResources = prepareLocalResources(appId, "spark")
+ val env = setupLaunchEnv(localResources)
+ val amContainer = createContainerLaunchContext(newApp, localResources, env)
+
+ appContext.setQueue(args.amQueue)
+ appContext.setAMContainerSpec(amContainer)
+ appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
+
+ submitApp(appContext)
+
+ monitorApplication(appId)
+ System.exit(0)
+ }
+
+
+ def logClusterResourceDetails() {
+ val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
+ logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)
+
+ val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
+ logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
+ ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
+ ", queueChildQueueCount=" + queueInfo.getChildQueues.size)
+ }
+
+
+ def verifyClusterResources(app: GetNewApplicationResponse) = {
+ val maxMem = app.getMaximumResourceCapability().getMemory()
+ logInfo("Max mem capabililty of a single resource in this cluster " + maxMem)
+
+ // if we have requested more then the clusters max for a single resource then exit.
+ if (args.workerMemory > maxMem) {
+ logError("the worker size is to large to run on this cluster " + args.workerMemory);
+ System.exit(1)
+ }
+ val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ if (amMem > maxMem) {
+ logError("AM size is to large to run on this cluster " + amMem)
+ System.exit(1)
+ }
+
+ // We could add checks to make sure the entire cluster has enough resources but that involves getting
+ // all the node reports and computing ourselves
+ }
+
+ def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
+ logInfo("Setting up application submission context for ASM")
+ val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
+ appContext.setApplicationId(appId)
+ appContext.setApplicationName("Spark")
+ return appContext
+ }
+
+ def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val locaResources = HashMap[String, LocalResource]()
+ // Upload Spark and the application JAR to the remote file system
+ // Add them as local resources to the AM
+ val fs = FileSystem.get(conf)
+
+ val delegTokenRenewer = Master.getMasterPrincipal(conf);
+ if (UserGroupInformation.isSecurityEnabled()) {
+ if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
+ logError("Can't get Master Kerberos principal for use as renewer")
+ System.exit(1)
+ }
+ }
+
+ Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
+ .foreach { case(destName, _localPath) =>
+ val localPath: String = if (_localPath != null) _localPath.trim() else ""
+ if (! localPath.isEmpty()) {
+ val src = new Path(localPath)
+ val pathSuffix = appName + "/" + appId.getId() + destName
+ val dst = new Path(fs.getHomeDirectory(), pathSuffix)
+ logInfo("Uploading " + src + " to " + dst)
+ fs.copyFromLocalFile(false, true, src, dst)
+ val destStatus = fs.getFileStatus(dst)
+
+ // get tokens for anything we upload to hdfs
+ if (UserGroupInformation.isSecurityEnabled()) {
+ fs.addDelegationTokens(delegTokenRenewer, credentials);
+ }
+
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(LocalResourceType.FILE)
+ amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ locaResources(destName) = amJarRsrc
+ }
+ }
+ UserGroupInformation.getCurrentUser().addCredentials(credentials);
+ return locaResources
+ }
+
+ def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
+ logInfo("Setting up the launch environment")
+ val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)
+
+ val env = new HashMap[String, String]()
+
+ // If log4j present, ensure ours overrides all others
+ if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Client.populateHadoopClasspath(yarnConf, env)
+ SparkHadoopUtil.setYarnMode(env)
+ env("SPARK_YARN_JAR_PATH") =
+ localResources("spark.jar").getResource().getScheme.toString() + "://" +
+ localResources("spark.jar").getResource().getFile().toString()
+ env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
+ env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()
+
+ env("SPARK_YARN_USERJAR_PATH") =
+ localResources("app.jar").getResource().getScheme.toString() + "://" +
+ localResources("app.jar").getResource().getFile().toString()
+ env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
+ env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()
+
+ if (log4jConfLocalRes != null) {
+ env("SPARK_YARN_LOG4J_PATH") =
+ log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
+ env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
+ env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
+ }
+
+
+ // Add each SPARK-* key to the environment
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ return env
+ }
+
+ def userArgsToString(clientArgs: ClientArguments): String = {
+ val prefix = " --args "
+ val args = clientArgs.userArgs
+ val retval = new StringBuilder()
+ for (arg <- args){
+ retval.append(prefix).append(" '").append(arg).append("' ")
+ }
+
+ retval.toString
+ }
+
+ def createContainerLaunchContext(newApp: GetNewApplicationResponse,
+ localResources: HashMap[String, LocalResource],
+ env: HashMap[String, String]): ContainerLaunchContext = {
+ logInfo("Setting up container launch context")
+ val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
+ amContainer.setLocalResources(localResources)
+ amContainer.setEnvironment(env)
+
+ val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
+
+ var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
+ (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+
+ // Add Xmx for am memory
+ JAVA_OPTS += "-Xmx" + amMemory + "m "
+
+ // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
+ // node, spark gc effects all other containers performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is
+ // limited to subset of cores on a node.
+ if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
+ // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tenant machines
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+
+ // Command for the ApplicationMaster
+ var javaCommand = "java";
+ val javaHome = System.getenv("JAVA_HOME")
+ if (javaHome != null && !javaHome.isEmpty()) {
+ javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
+ }
+
+ val commands = List[String](javaCommand +
+ " -server " +
+ JAVA_OPTS +
+ " spark.deploy.yarn.ApplicationMaster" +
+ " --class " + args.userClass +
+ " --jar " + args.userJar +
+ userArgsToString(args) +
+ " --worker-memory " + args.workerMemory +
+ " --worker-cores " + args.workerCores +
+ " --num-workers " + args.numWorkers +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Command for the ApplicationMaster: " + commands(0))
+ amContainer.setCommands(commands)
+
+ val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
+ // Memory for the ApplicationMaster
+ capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ amContainer.setResource(capability)
+
+ // Setup security tokens
+ val dob = new DataOutputBuffer()
+ credentials.writeTokenStorageToStream(dob)
+ amContainer.setContainerTokens(ByteBuffer.wrap(dob.getData()))
+
+ return amContainer
+ }
+
+ def submitApp(appContext: ApplicationSubmissionContext) = {
+ // Submit the application to the applications manager
+ logInfo("Submitting application to ASM")
+ super.submitApplication(appContext)
+ }
+
+ def monitorApplication(appId: ApplicationId): Boolean = {
+ while(true) {
+ Thread.sleep(1000)
+ val report = super.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t application identifier: " + appId.toString() + "\n" +
+ "\t appId: " + appId.getId() + "\n" +
+ "\t clientToken: " + report.getClientToken() + "\n" +
+ "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
+ "\t appMasterHost: " + report.getHost() + "\n" +
+ "\t appQueue: " + report.getQueue() + "\n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
+ "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
+ "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
+ "\t appUser: " + report.getUser()
+ )
+
+ val state = report.getYarnApplicationState()
+ val dsStatus = report.getFinalApplicationStatus()
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ return true
+ }
+ }
+ return true
+ }
+}
+
+object Client {
+ def main(argStrings: Array[String]) {
+ val args = new ClientArguments(argStrings)
+ SparkHadoopUtil.setYarnMode()
+ new Client(args).run
+ }
+
+ // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
+ def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
+ for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
+ }
+ }
+}
diff --git a/yarn/src/main/scala/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/spark/deploy/yarn/ClientArguments.scala
new file mode 100644
index 0000000000..67aff03781
--- /dev/null
+++ b/yarn/src/main/scala/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.deploy.yarn
+
+import spark.util.MemoryParam
+import spark.util.IntParam
+import collection.mutable.{ArrayBuffer, HashMap}
+import spark.scheduler.{InputFormatInfo, SplitInfo}
+
+// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
+class ClientArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+ var amQueue = System.getProperty("QUEUE", "default")
+ var amMemory: Int = 512
+ // TODO
+ var inputFormatInfo: List[InputFormatInfo] = null
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
+ val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: MemoryParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case ("--queue") :: value :: tail =>
+ amQueue = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ inputFormatInfo = inputFormatMap.values.toList
+ }
+
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: spark.deploy.yarn.Client [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')"
+ )
+ System.exit(exitCode)
+ }
+
+}
diff --git a/yarn/src/main/scala/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/spark/deploy/yarn/WorkerRunnable.scala
new file mode 100644
index 0000000000..f458f2f6a1
--- /dev/null
+++ b/yarn/src/main/scala/spark/deploy/yarn/WorkerRunnable.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.deploy.yarn
+
+import java.net.URI
+import java.nio.ByteBuffer
+import java.security.PrivilegedExceptionAction
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils}
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+
+import spark.{Logging, Utils}
+
+class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
+ slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
+ extends Runnable with Logging {
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ var cm: ContainerManager = null
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run = {
+ logInfo("Starting Worker Container")
+ cm = connectToCM
+ startContainer
+ }
+
+ def startContainer = {
+ logInfo("Setting up ContainerLaunchContext")
+
+ val ctx = Records.newRecord(classOf[ContainerLaunchContext])
+ .asInstanceOf[ContainerLaunchContext]
+
+ ctx.setContainerId(container.getId())
+ ctx.setResource(container.getResource())
+ val localResources = prepareLocalResources
+ ctx.setLocalResources(localResources)
+
+ val env = prepareEnvironment
+ ctx.setEnvironment(env)
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+ // Set the JVM memory
+ val workerMemoryString = workerMemory + "m"
+ JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+ // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
+ // node, spark gc effects all other containers performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is
+ // limited to subset of cores on a node.
+/*
+ else {
+ // If no java_opts specified, default to using -XX:+CMSIncrementalMode
+ // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont want to mess with it.
+ // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tennent machines
+ // The options are based on
+ // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+*/
+
+ ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
+
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+ val dob = new DataOutputBuffer()
+ credentials.writeTokenStorageToStream(dob)
+ ctx.setContainerTokens(ByteBuffer.wrap(dob.getData()))
+
+ var javaCommand = "java";
+ val javaHome = System.getenv("JAVA_HOME")
+ if (javaHome != null && !javaHome.isEmpty()) {
+ javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
+ }
+
+ val commands = List[String](javaCommand +
+ " -server " +
+ // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
+ // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state.
+ // TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ?
+ " -XX:OnOutOfMemoryError='kill %p' " +
+ JAVA_OPTS +
+ " spark.executor.StandaloneExecutorBackend " +
+ masterAddress + " " +
+ slaveId + " " +
+ hostname + " " +
+ workerCores +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Setting up worker with commands: " + commands)
+ ctx.setCommands(commands)
+
+ // Send the start request to the ContainerManager
+ val startReq = Records.newRecord(classOf[StartContainerRequest])
+ .asInstanceOf[StartContainerRequest]
+ startReq.setContainerLaunchContext(ctx)
+ cm.startContainer(startReq)
+ }
+
+
+ def prepareLocalResources: HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val locaResources = HashMap[String, LocalResource]()
+
+ // Spark JAR
+ val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ sparkJarResource.setType(LocalResourceType.FILE)
+ sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
+ sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
+ sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
+ locaResources("spark.jar") = sparkJarResource
+ // User JAR
+ val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ userJarResource.setType(LocalResourceType.FILE)
+ userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
+ userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
+ userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
+ locaResources("app.jar") = userJarResource
+
+ // Log4j conf - if available
+ if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
+ val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ log4jConfResource.setType(LocalResourceType.FILE)
+ log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
+ log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
+ log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
+ locaResources("log4j.properties") = log4jConfResource
+ }
+
+
+ logInfo("Prepared Local resources " + locaResources)
+ return locaResources
+ }
+
+ def prepareEnvironment: HashMap[String, String] = {
+ val env = new HashMap[String, String]()
+
+ // If log4j present, ensure ours overrides all others
+ if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
+ // Which is correct ?
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+ }
+
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Client.populateHadoopClasspath(yarnConf, env)
+
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ return env
+ }
+
+ def connectToCM: ContainerManager = {
+ val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
+ val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
+ logInfo("Connecting to ContainerManager at " + cmHostPortStr)
+
+ // use doAs and remoteUser here so we can add the container token and not
+ // pollute the current users credentials with all of the individual container tokens
+ val user = UserGroupInformation.createRemoteUser(container.getId().toString());
+ val containerToken = container.getContainerToken();
+ if (containerToken != null) {
+ user.addToken(ProtoUtils.convertFromProtoFormat(containerToken, cmAddress))
+ }
+
+ val proxy = user
+ .doAs(new PrivilegedExceptionAction[ContainerManager] {
+ def run: ContainerManager = {
+ return rpc.getProxy(classOf[ContainerManager],
+ cmAddress, conf).asInstanceOf[ContainerManager]
+ }
+ });
+ return proxy;
+ }
+
+}
diff --git a/yarn/src/main/scala/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/spark/deploy/yarn/YarnAllocationHandler.scala
new file mode 100644
index 0000000000..b0af8baf08
--- /dev/null
+++ b/yarn/src/main/scala/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -0,0 +1,564 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.deploy.yarn
+
+import spark.{Logging, Utils}
+import spark.scheduler.SplitInfo
+import scala.collection
+import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
+import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
+import org.apache.hadoop.yarn.util.{RackResolver, Records}
+import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.hadoop.yarn.api.AMRMProtocol
+import collection.JavaConversions._
+import collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import org.apache.hadoop.conf.Configuration
+import java.util.{Collections, Set => JSet}
+import java.lang.{Boolean => JBoolean}
+
+object AllocationType extends Enumeration ("HOST", "RACK", "ANY") {
+ type AllocationType = Value
+ val HOST, RACK, ANY = Value
+}
+
+// too many params ? refactor it 'somehow' ?
+// needs to be mt-safe
+// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it
+// more proactive and decoupled.
+// Note that right now, we assume all node asks as uniform in terms of capabilities and priority
+// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
+// on how we are requesting for containers.
+private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
+ val appAttemptId: ApplicationAttemptId,
+ val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
+ val preferredHostToCount: Map[String, Int],
+ val preferredRackToCount: Map[String, Int])
+ extends Logging {
+
+
+ // These three are locked on allocatedHostToContainersMap. Complementary data structures
+ // allocatedHostToContainersMap : containers which are running : host, Set<containerid>
+ // allocatedContainerToHostMap: container to host mapping
+ private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
+ private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
+ // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node)
+ // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap
+ private val allocatedRackCount = new HashMap[String, Int]()
+
+ // containers which have been released.
+ private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
+ // containers to be released in next request to RM
+ private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
+
+ private val numWorkersRunning = new AtomicInteger()
+ // Used to generate a unique id per worker
+ private val workerIdCounter = new AtomicInteger()
+ private val lastResponseId = new AtomicInteger()
+
+ def getNumWorkersRunning: Int = numWorkersRunning.intValue
+
+
+ def isResourceConstraintSatisfied(container: Container): Boolean = {
+ container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ }
+
+ def allocateContainers(workersToRequest: Int) {
+ // We need to send the request only once from what I understand ... but for now, not modifying this much.
+
+ // Keep polling the Resource Manager for containers
+ val amResp = allocateWorkerResources(workersToRequest).getAMResponse
+
+ val _allocatedContainers = amResp.getAllocatedContainers()
+ if (_allocatedContainers.size > 0) {
+
+
+ logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
+ numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
+ ", pendingReleaseContainers : " + pendingReleaseContainers)
+ logDebug("Cluster Resources: " + amResp.getAvailableResources)
+
+ val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ // ignore if not satisfying constraints {
+ for (container <- _allocatedContainers) {
+ if (isResourceConstraintSatisfied(container)) {
+ // allocatedContainers += container
+
+ val host = container.getNodeId.getHost
+ val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())
+
+ containers += container
+ }
+ // Add all ignored containers to released list
+ else releasedContainerList.add(container.getId())
+ }
+
+ // Find the appropriate containers to use
+ // Slightly non trivial groupBy I guess ...
+ val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ for (candidateHost <- hostToContainers.keySet)
+ {
+ val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
+ val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
+
+ var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
+ assert(remainingContainers != null)
+
+ if (requiredHostCount >= remainingContainers.size){
+ // Since we got <= required containers, add all to dataLocalContainers
+ dataLocalContainers.put(candidateHost, remainingContainers)
+ // all consumed
+ remainingContainers = null
+ }
+ else if (requiredHostCount > 0) {
+ // container list has more containers than we need for data locality.
+ // Split into two : data local container count of (remainingContainers.size - requiredHostCount)
+ // and rest as remainingContainer
+ val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
+ dataLocalContainers.put(candidateHost, dataLocal)
+ // remainingContainers = remaining
+
+ // yarn has nasty habit of allocating a tonne of containers on a host - discourage this :
+ // add remaining to release list. If we have insufficient containers, next allocation cycle
+ // will reallocate (but wont treat it as data local)
+ for (container <- remaining) releasedContainerList.add(container.getId())
+ remainingContainers = null
+ }
+
+ // now rack local
+ if (remainingContainers != null){
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+
+ if (rack != null){
+ val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
+ val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
+ rackLocalContainers.get(rack).getOrElse(List()).size
+
+
+ if (requiredRackCount >= remainingContainers.size){
+ // Add all to dataLocalContainers
+ dataLocalContainers.put(rack, remainingContainers)
+ // all consumed
+ remainingContainers = null
+ }
+ else if (requiredRackCount > 0) {
+ // container list has more containers than we need for data locality.
+ // Split into two : data local container count of (remainingContainers.size - requiredRackCount)
+ // and rest as remainingContainer
+ val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
+ val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
+
+ existingRackLocal ++= rackLocal
+ remainingContainers = remaining
+ }
+ }
+ }
+
+ // If still not consumed, then it is off rack host - add to that list.
+ if (remainingContainers != null){
+ offRackContainers.put(candidateHost, remainingContainers)
+ }
+ }
+
+ // Now that we have split the containers into various groups, go through them in order :
+ // first host local, then rack local and then off rack (everything else).
+ // Note that the list we create below tries to ensure that not all containers end up within a host
+ // if there are sufficiently large number of hosts/containers.
+
+ val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+
+ // Run each of the allocated containers
+ for (container <- allocatedContainers) {
+ val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
+ val workerHostname = container.getNodeId.getHost
+ val containerId = container.getId
+
+ assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
+
+ if (numWorkersRunningNow > maxWorkers) {
+ logInfo("Ignoring container " + containerId + " at host " + workerHostname +
+ " .. we already have required number of containers")
+ releasedContainerList.add(containerId)
+ // reset counter back to old value.
+ numWorkersRunning.decrementAndGet()
+ }
+ else {
+ // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter)
+ val workerId = workerIdCounter.incrementAndGet().toString
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ StandaloneSchedulerBackend.ACTOR_NAME)
+
+ logInfo("launching container on " + containerId + " host " + workerHostname)
+ // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
+ pendingReleaseContainers.remove(containerId)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
+ allocatedHostToContainersMap.synchronized {
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())
+
+ containerSet += containerId
+ allocatedContainerToHostMap.put(containerId, workerHostname)
+ if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+ }
+
+ new Thread(
+ new WorkerRunnable(container, conf, driverUrl, workerId,
+ workerHostname, workerMemory, workerCores)
+ ).start()
+ }
+ }
+ logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
+ _allocatedContainers.size + "), current count " + numWorkersRunning.get() +
+ ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+ }
+
+
+ val completedContainers = amResp.getCompletedContainersStatuses()
+ if (completedContainers.size > 0){
+ logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
+ ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+
+ for (completedContainer <- completedContainers){
+ val containerId = completedContainer.getContainerId
+
+ // Was this released by us ? If yes, then simply remove from containerSet and move on.
+ if (pendingReleaseContainers.containsKey(containerId)) {
+ pendingReleaseContainers.remove(containerId)
+ }
+ else {
+ // simply decrement count - next iteration of ReporterThread will take care of allocating !
+ numWorkersRunning.decrementAndGet()
+ logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
+ " httpaddress: " + completedContainer.getDiagnostics)
+ }
+
+ allocatedHostToContainersMap.synchronized {
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
+ assert (host != null)
+
+ val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
+ assert (containerSet != null)
+
+ containerSet -= containerId
+ if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
+ else allocatedHostToContainersMap.update(host, containerSet)
+
+ allocatedContainerToHostMap -= containerId
+
+ // doing this within locked context, sigh ... move to outside ?
+ val rack = YarnAllocationHandler.lookupRack(conf, host)
+ if (rack != null) {
+ val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
+ if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
+ else allocatedRackCount.remove(rack)
+ }
+ }
+ }
+ }
+ logDebug("After completed " + completedContainers.size + " containers, current count " +
+ numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
+ ", pendingReleaseContainers : " + pendingReleaseContainers)
+ }
+ }
+
+ def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
+ // First generate modified racks and new set of hosts under it : then issue requests
+ val rackToCounts = new HashMap[String, Int]()
+
+ // Within this lock - used to read/write to the rack related maps too.
+ for (container <- hostContainers) {
+ val candidateHost = container.getHostName
+ val candidateNumContainers = container.getNumContainers
+ assert(YarnAllocationHandler.ANY_HOST != candidateHost)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+ if (rack != null) {
+ var count = rackToCounts.getOrElse(rack, 0)
+ count += candidateNumContainers
+ rackToCounts.put(rack, count)
+ }
+ }
+
+ val requestedContainers: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](rackToCounts.size)
+ for ((rack, count) <- rackToCounts){
+ requestedContainers +=
+ createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
+ }
+
+ requestedContainers.toList
+ }
+
+ def allocatedContainersOnHost(host: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
+ }
+ retval
+ }
+
+ def allocatedContainersOnRack(rack: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedRackCount.getOrElse(rack, 0)
+ }
+ retval
+ }
+
+ private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {
+
+ var resourceRequests: List[ResourceRequest] = null
+
+ // default.
+ if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
+ logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
+ resourceRequests = List(
+ createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
+ }
+ else {
+ // request for all hosts in preferred nodes and for numWorkers -
+ // candidates.size, request by default allocation policy.
+ val hostContainerRequests: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
+ for ((candidateHost, candidateCount) <- preferredHostToCount) {
+ val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
+
+ if (requiredCount > 0) {
+ hostContainerRequests +=
+ createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
+ }
+ }
+ val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)
+
+ val anyContainerRequests: ResourceRequest =
+ createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)
+
+ val containerRequests: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1)
+
+ containerRequests ++= hostContainerRequests
+ containerRequests ++= rackContainerRequests
+ containerRequests += anyContainerRequests
+
+ resourceRequests = containerRequests.toList
+ }
+
+ val req = Records.newRecord(classOf[AllocateRequest])
+ req.setResponseId(lastResponseId.incrementAndGet)
+ req.setApplicationAttemptId(appAttemptId)
+
+ req.addAllAsks(resourceRequests)
+
+ val releasedContainerList = createReleasedContainerList()
+ req.addAllReleases(releasedContainerList)
+
+
+
+ if (numWorkers > 0) {
+ logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
+ }
+ else {
+ logDebug("Empty allocation req .. release : " + releasedContainerList)
+ }
+
+ for (req <- resourceRequests) {
+ logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers +
+ ", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability)
+ }
+ resourceManager.allocate(req)
+ }
+
+
+ private def createResourceRequest(requestType: AllocationType.AllocationType,
+ resource:String, numWorkers: Int, priority: Int): ResourceRequest = {
+
+ // If hostname specified, we need atleast two requests - node local and rack local.
+ // There must be a third request - which is ANY : that will be specially handled.
+ requestType match {
+ case AllocationType.HOST => {
+ assert (YarnAllocationHandler.ANY_HOST != resource)
+
+ val hostname = resource
+ val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
+
+ // add to host->rack mapping
+ YarnAllocationHandler.populateRackInfo(conf, hostname)
+
+ nodeLocal
+ }
+
+ case AllocationType.RACK => {
+ val rack = resource
+ createResourceRequestImpl(rack, numWorkers, priority)
+ }
+
+ case AllocationType.ANY => {
+ createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
+ }
+
+ case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
+ }
+ }
+
+ private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = {
+
+ val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
+ val memCapability = Records.newRecord(classOf[Resource])
+ // There probably is some overhead here, let's reserve a bit more memory.
+ memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ rsrcRequest.setCapability(memCapability)
+
+ val pri = Records.newRecord(classOf[Priority])
+ pri.setPriority(priority)
+ rsrcRequest.setPriority(pri)
+
+ rsrcRequest.setHostName(hostname)
+
+ rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
+ rsrcRequest
+ }
+
+ def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
+
+ val retval = new ArrayBuffer[ContainerId](1)
+ // iterator on COW list ...
+ for (container <- releasedContainerList.iterator()){
+ retval += container
+ }
+ // remove from the original list.
+ if (! retval.isEmpty) {
+ releasedContainerList.removeAll(retval)
+ for (v <- retval) pendingReleaseContainers.put(v, true)
+ logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
+ pendingReleaseContainers)
+ }
+
+ retval
+ }
+}
+
+object YarnAllocationHandler {
+
+ val ANY_HOST = "*"
+ // all requests are issued with same priority : we do not (yet) have any distinction between
+ // request types (like map/reduce in hadoop for example)
+ val PRIORITY = 1
+
+ // Additional memory overhead - in mb
+ val MEMORY_OVERHEAD = 384
+
+ // host to rack map - saved from allocation requests
+ // We are expecting this not to change.
+ // Note that it is possible for this to change : and RM will indicate that to us via update
+ // response to allocate. But we are punting on handling that for now.
+ private val hostToRack = new ConcurrentHashMap[String, String]()
+ private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+
+ def newAllocator(conf: Configuration,
+ resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+
+ new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
+ args.workerMemory, args.workerCores, hostToCount, rackToCount)
+ }
+
+ def newAllocator(conf: Configuration,
+ resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+ maxWorkers: Int, workerMemory: Int, workerCores: Int,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+ new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
+ workerMemory, workerCores, hostToCount, rackToCount)
+ }
+
+ // A simple method to copy the split info map.
+ private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
+ // host to count, rack to count
+ (Map[String, Int], Map[String, Int]) = {
+
+ if (input == null) return (Map[String, Int](), Map[String, Int]())
+
+ val hostToCount = new HashMap[String, Int]
+ val rackToCount = new HashMap[String, Int]
+
+ for ((host, splits) <- input) {
+ val hostCount = hostToCount.getOrElse(host, 0)
+ hostToCount.put(host, hostCount + splits.size)
+
+ val rack = lookupRack(conf, host)
+ if (rack != null){
+ val rackCount = rackToCount.getOrElse(host, 0)
+ rackToCount.put(host, rackCount + splits.size)
+ }
+ }
+
+ (hostToCount.toMap, rackToCount.toMap)
+ }
+
+ def lookupRack(conf: Configuration, host: String): String = {
+ if (! hostToRack.contains(host)) populateRackInfo(conf, host)
+ hostToRack.get(host)
+ }
+
+ def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
+ val set = rackToHostSet.get(rack)
+ if (set == null) return None
+
+ // No better way to get a Set[String] from JSet ?
+ val convertedSet: collection.mutable.Set[String] = set
+ Some(convertedSet.toSet)
+ }
+
+ def populateRackInfo(conf: Configuration, hostname: String) {
+ Utils.checkHost(hostname)
+
+ if (!hostToRack.containsKey(hostname)) {
+ // If there are repeated failures to resolve, all to an ignore list ?
+ val rackInfo = RackResolver.resolve(conf, hostname)
+ if (rackInfo != null && rackInfo.getNetworkLocation != null) {
+ val rack = rackInfo.getNetworkLocation
+ hostToRack.put(hostname, rack)
+ if (! rackToHostSet.containsKey(rack)) {
+ rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
+ }
+ rackToHostSet.get(rack).add(hostname)
+
+ // Since RackResolver caches, we are disabling this for now ...
+ } /* else {
+ // right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
+ hostToRack.put(hostname, null)
+ } */
+ }
+ }
+}
diff --git a/yarn/src/main/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
new file mode 100644
index 0000000000..307d96111c
--- /dev/null
+++ b/yarn/src/main/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.scheduler.cluster
+
+import spark._
+import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.hadoop.conf.Configuration
+
+/**
+ *
+ * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
+ */
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate
+ // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?)
+ // Subsequent creations are ignored - since nodes are already allocated by then.
+
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ // By default, if rack is unknown, return nothing
+ override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
+ if (rack == None || rack == null) return None
+
+ YarnAllocationHandler.fetchCachedHostsForRack(rack)
+ }
+
+ override def postStartHook() {
+ val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
+ if (sparkContextInitialized){
+ // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+ Thread.sleep(3000L)
+ }
+ logInfo("YarnClusterScheduler.postStartHook done")
+ }
+}