author     Marcelo Vanzin <vanzin@cloudera.com>  2016-03-07 14:13:44 -0800
committer  Marcelo Vanzin <vanzin@cloudera.com>  2016-03-07 14:13:44 -0800
commit     e1fb857992074164dcaa02498c5a9604fac6f57e (patch)
tree       5f2a9de0230df4ebd0ca7317c879472eb8d3fbbc /yarn
parent     e9e67b39abb23a88d8be2d0fea5b5fd93184a25b (diff)
[SPARK-529][CORE][YARN] Add type-safe config keys to SparkConf.
This is, in a way, the basics to enable SPARK-529 (which was closed as "won't fix" but I think is still valuable). In fact, Spark SQL created something for that, and this change basically factors out that code and inserts it into SparkConf, with some extra bells and whistles.

To showcase the usage of this pattern, I modified the YARN backend to use the new config keys (defined in the new `config` package object under `o.a.s.deploy.yarn`). Most of the changes are mechanical, although the logic had to be slightly modified in a handful of places.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #10205 from vanzin/conf-opts.
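For readers skimming the patch, the new pattern boils down to roughly the sketch below. This is illustrative only: the real entries live in the new `o.a.s.deploy.yarn.config` package object further down in this diff, the `private[spark]` modifiers are dropped for brevity, and the typed get/set calls on SparkConf are the ones this change adds (so they are usable only from Spark's own code).

    import java.util.concurrent.TimeUnit

    import org.apache.spark.SparkConf
    import org.apache.spark.internal.config.ConfigBuilder

    // Each key is declared exactly once, together with its type, default and doc.
    val RM_HEARTBEAT_INTERVAL =
      ConfigBuilder("spark.yarn.scheduler.heartbeat.interval-ms")
        .timeConf(TimeUnit.MILLISECONDS)
        .withDefaultString("3s")

    val MAX_APP_ATTEMPTS = ConfigBuilder("spark.yarn.maxAppAttempts")
      .doc("Maximum number of AM attempts before failing the app.")
      .intConf
      .optional

    // Callers go through the entry instead of a raw string key, so values come
    // back already parsed and defaulted.
    val conf = new SparkConf()
    val heartbeatMs: Long = conf.get(RM_HEARTBEAT_INTERVAL)    // 3000 unless overridden
    val maxAttempts: Option[Int] = conf.get(MAX_APP_ATTEMPTS)  // None when unset
    conf.set(MAX_APP_ATTEMPTS, 42)                             // type-checked set
    conf.set(RM_HEARTBEAT_INTERVAL.key, "5s")                  // plain string key still works

In the YARN code this replaces scattered calls such as sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s") with a one-line conf.get(RM_HEARTBEAT_INTERVAL), which accounts for most of the mechanical churn in the files listed below.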
Diffstat (limited to 'yarn')
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala  14
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala  28
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala  230
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala  53
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala  3
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala  14
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala  6
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala  10
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala  3
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala  18
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala  243
-rw-r--r--  yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala  32
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala  26
-rw-r--r--  yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala  4
14 files changed, 443 insertions, 241 deletions
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
index 70b67d21ec..6e95bb9710 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
@@ -27,6 +27,8 @@ import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.config._
import org.apache.spark.util.ThreadUtils
/*
@@ -60,11 +62,9 @@ private[yarn] class AMDelegationTokenRenewer(
private val hadoopUtil = YarnSparkHadoopUtil.get
- private val credentialsFile = sparkConf.get("spark.yarn.credentials.file")
- private val daysToKeepFiles =
- sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5)
- private val numFilesToKeep =
- sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5)
+ private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH)
+ private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION)
+ private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT)
private val freshHadoopConf =
hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme)
@@ -76,8 +76,8 @@ private[yarn] class AMDelegationTokenRenewer(
*
*/
private[spark] def scheduleLoginFromKeytab(): Unit = {
- val principal = sparkConf.get("spark.yarn.principal")
- val keytab = sparkConf.get("spark.yarn.keytab")
+ val principal = sparkConf.get(PRINCIPAL).get
+ val keytab = sparkConf.get(KEYTAB).get
/**
* Schedule re-login and creation of new tokens. If tokens have already expired, this method
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 9f586bf4c1..7d7bf88b9e 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -32,6 +32,8 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.history.HistoryServer
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.config._
import org.apache.spark.rpc._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -65,16 +67,15 @@ private[spark] class ApplicationMaster(
// allocation is enabled), with a minimum of 3.
private val maxNumExecutorFailures = {
- val defaultKey =
+ val effectiveNumExecutors =
if (Utils.isDynamicAllocationEnabled(sparkConf)) {
- "spark.dynamicAllocation.maxExecutors"
+ sparkConf.get(DYN_ALLOCATION_MAX_EXECUTORS)
} else {
- "spark.executor.instances"
+ sparkConf.get(EXECUTOR_INSTANCES).getOrElse(0)
}
- val effectiveNumExecutors = sparkConf.getInt(defaultKey, 0)
val defaultMaxNumExecutorFailures = math.max(3, 2 * effectiveNumExecutors)
- sparkConf.getInt("spark.yarn.max.executor.failures", defaultMaxNumExecutorFailures)
+ sparkConf.get(MAX_EXECUTOR_FAILURES).getOrElse(defaultMaxNumExecutorFailures)
}
@volatile private var exitCode = 0
@@ -95,14 +96,13 @@ private[spark] class ApplicationMaster(
private val heartbeatInterval = {
// Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
- math.max(0, math.min(expiryInterval / 2,
- sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s")))
+ math.max(0, math.min(expiryInterval / 2, sparkConf.get(RM_HEARTBEAT_INTERVAL)))
}
// Initial wait interval before allocator poll, to allow for quicker ramp up when executors are
// being requested.
private val initialAllocationInterval = math.min(heartbeatInterval,
- sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms"))
+ sparkConf.get(INITIAL_HEARTBEAT_INTERVAL))
// Next wait interval before allocator poll.
private var nextAllocationInterval = initialAllocationInterval
@@ -178,7 +178,7 @@ private[spark] class ApplicationMaster(
// If the credentials file config is present, we must periodically renew tokens. So create
// a new AMDelegationTokenRenewer
- if (sparkConf.contains("spark.yarn.credentials.file")) {
+ if (sparkConf.contains(CREDENTIALS_FILE_PATH.key)) {
delegationTokenRenewerOption = Some(new AMDelegationTokenRenewer(sparkConf, yarnConf))
// If a principal and keytab have been set, use that to create new credentials for executors
// periodically
@@ -275,7 +275,7 @@ private[spark] class ApplicationMaster(
val appId = client.getAttemptId().getApplicationId().toString()
val attemptId = client.getAttemptId().getAttemptId().toString()
val historyAddress =
- sparkConf.getOption("spark.yarn.historyServer.address")
+ sparkConf.get(HISTORY_SERVER_ADDRESS)
.map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) }
.map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
.getOrElse("")
@@ -355,7 +355,7 @@ private[spark] class ApplicationMaster(
private def launchReporterThread(): Thread = {
// The number of failures in a row until Reporter thread give up
- val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5)
+ val reporterMaxFailures = sparkConf.get(MAX_REPORTER_THREAD_FAILURES)
val t = new Thread {
override def run() {
@@ -429,7 +429,7 @@ private[spark] class ApplicationMaster(
private def cleanupStagingDir(fs: FileSystem) {
var stagingDirPath: Path = null
try {
- val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false)
+ val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES)
if (!preserveFiles) {
stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR"))
if (stagingDirPath == null) {
@@ -448,7 +448,7 @@ private[spark] class ApplicationMaster(
private def waitForSparkContextInitialized(): SparkContext = {
logInfo("Waiting for spark context initialization")
sparkContextRef.synchronized {
- val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s")
+ val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
val deadline = System.currentTimeMillis() + totalWaitTime
while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
@@ -473,7 +473,7 @@ private[spark] class ApplicationMaster(
// Spark driver should already be up since it launched us, but we don't want to
// wait forever, so wait 100 seconds max to match the cluster mode setting.
- val totalWaitTimeMs = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s")
+ val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME)
val deadline = System.currentTimeMillis + totalWaitTimeMs
while (!driverUp && !finished && System.currentTimeMillis < deadline) {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index be45e9597f..36073de90d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -51,6 +51,8 @@ import org.apache.hadoop.yarn.util.Records
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.config._
import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils}
import org.apache.spark.util.Utils
@@ -87,8 +89,7 @@ private[spark] class Client(
}
}
}
- private val fireAndForget = isClusterMode &&
- !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true)
+ private val fireAndForget = isClusterMode && !sparkConf.get(WAIT_FOR_APP_COMPLETION)
private var appId: ApplicationId = null
@@ -156,7 +157,7 @@ private[spark] class Client(
private def cleanupStagingDir(appId: ApplicationId): Unit = {
val appStagingDir = getAppStagingDir(appId)
try {
- val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false)
+ val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES)
val stagingDirPath = new Path(appStagingDir)
val fs = FileSystem.get(hadoopConf)
if (!preserveFiles && fs.exists(stagingDirPath)) {
@@ -181,39 +182,36 @@ private[spark] class Client(
appContext.setQueue(args.amQueue)
appContext.setAMContainerSpec(containerContext)
appContext.setApplicationType("SPARK")
- sparkConf.getOption(CONF_SPARK_YARN_APPLICATION_TAGS)
- .map(StringUtils.getTrimmedStringCollection(_))
- .filter(!_.isEmpty())
- .foreach { tagCollection =>
- try {
- // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use
- // reflection to set it, printing a warning if a tag was specified but the YARN version
- // doesn't support it.
- val method = appContext.getClass().getMethod(
- "setApplicationTags", classOf[java.util.Set[String]])
- method.invoke(appContext, new java.util.HashSet[String](tagCollection))
- } catch {
- case e: NoSuchMethodException =>
- logWarning(s"Ignoring $CONF_SPARK_YARN_APPLICATION_TAGS because this version of " +
- "YARN does not support it")
- }
+
+ sparkConf.get(APPLICATION_TAGS).foreach { tags =>
+ try {
+ // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use
+ // reflection to set it, printing a warning if a tag was specified but the YARN version
+ // doesn't support it.
+ val method = appContext.getClass().getMethod(
+ "setApplicationTags", classOf[java.util.Set[String]])
+ method.invoke(appContext, new java.util.HashSet[String](tags.asJava))
+ } catch {
+ case e: NoSuchMethodException =>
+ logWarning(s"Ignoring ${APPLICATION_TAGS.key} because this version of " +
+ "YARN does not support it")
}
- sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt) match {
+ }
+ sparkConf.get(MAX_APP_ATTEMPTS) match {
case Some(v) => appContext.setMaxAppAttempts(v)
- case None => logDebug("spark.yarn.maxAppAttempts is not set. " +
+ case None => logDebug(s"${MAX_APP_ATTEMPTS.key} is not set. " +
"Cluster's default value will be used.")
}
- if (sparkConf.contains("spark.yarn.am.attemptFailuresValidityInterval")) {
+ sparkConf.get(ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).foreach { interval =>
try {
- val interval = sparkConf.getTimeAsMs("spark.yarn.am.attemptFailuresValidityInterval")
val method = appContext.getClass().getMethod(
"setAttemptFailuresValidityInterval", classOf[Long])
method.invoke(appContext, interval: java.lang.Long)
} catch {
case e: NoSuchMethodException =>
- logWarning("Ignoring spark.yarn.am.attemptFailuresValidityInterval because the version " +
- "of YARN does not support it")
+ logWarning(s"Ignoring ${ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS.key} because " +
+ "the version of YARN does not support it")
}
}
@@ -221,28 +219,28 @@ private[spark] class Client(
capability.setMemory(args.amMemory + amMemoryOverhead)
capability.setVirtualCores(args.amCores)
- if (sparkConf.contains("spark.yarn.am.nodeLabelExpression")) {
- try {
- val amRequest = Records.newRecord(classOf[ResourceRequest])
- amRequest.setResourceName(ResourceRequest.ANY)
- amRequest.setPriority(Priority.newInstance(0))
- amRequest.setCapability(capability)
- amRequest.setNumContainers(1)
- val amLabelExpression = sparkConf.get("spark.yarn.am.nodeLabelExpression")
- val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String])
- method.invoke(amRequest, amLabelExpression)
-
- val setResourceRequestMethod =
- appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest])
- setResourceRequestMethod.invoke(appContext, amRequest)
- } catch {
- case e: NoSuchMethodException =>
- logWarning("Ignoring spark.yarn.am.nodeLabelExpression because the version " +
- "of YARN does not support it")
- appContext.setResource(capability)
- }
- } else {
- appContext.setResource(capability)
+ sparkConf.get(AM_NODE_LABEL_EXPRESSION) match {
+ case Some(expr) =>
+ try {
+ val amRequest = Records.newRecord(classOf[ResourceRequest])
+ amRequest.setResourceName(ResourceRequest.ANY)
+ amRequest.setPriority(Priority.newInstance(0))
+ amRequest.setCapability(capability)
+ amRequest.setNumContainers(1)
+ val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String])
+ method.invoke(amRequest, expr)
+
+ val setResourceRequestMethod =
+ appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest])
+ setResourceRequestMethod.invoke(appContext, amRequest)
+ } catch {
+ case e: NoSuchMethodException =>
+ logWarning(s"Ignoring ${AM_NODE_LABEL_EXPRESSION.key} because the version " +
+ "of YARN does not support it")
+ appContext.setResource(capability)
+ }
+ case None =>
+ appContext.setResource(capability)
}
appContext
@@ -345,8 +343,8 @@ private[spark] class Client(
YarnSparkHadoopUtil.get.obtainTokenForHiveMetastore(sparkConf, hadoopConf, credentials)
YarnSparkHadoopUtil.get.obtainTokenForHBase(sparkConf, hadoopConf, credentials)
- val replication = sparkConf.getInt("spark.yarn.submit.file.replication",
- fs.getDefaultReplication(dst)).toShort
+ val replication = sparkConf.get(STAGING_FILE_REPLICATION).map(_.toShort)
+ .getOrElse(fs.getDefaultReplication(dst))
val localResources = HashMap[String, LocalResource]()
FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION))
@@ -419,7 +417,7 @@ private[spark] class Client(
logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" +
" via the YARN Secure Distributed Cache.")
val (_, localizedPath) = distribute(keytab,
- destName = Some(sparkConf.get("spark.yarn.keytab")),
+ destName = sparkConf.get(KEYTAB),
appMasterOnly = true)
require(localizedPath != null, "Keytab file already distributed.")
}
@@ -433,8 +431,8 @@ private[spark] class Client(
* (3) Spark property key to set if the scheme is not local
*/
List(
- (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR),
- (APP_JAR, args.userJar, CONF_SPARK_USER_JAR),
+ (SPARK_JAR_NAME, sparkJar(sparkConf), SPARK_JAR.key),
+ (APP_JAR_NAME, args.userJar, APP_JAR.key),
("log4j.properties", oldLog4jConf.orNull, null)
).foreach { case (destName, path, confKey) =>
if (path != null && !path.trim().isEmpty()) {
@@ -472,7 +470,7 @@ private[spark] class Client(
}
}
if (cachedSecondaryJarLinks.nonEmpty) {
- sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(","))
+ sparkConf.set(SECONDARY_JARS, cachedSecondaryJarLinks)
}
if (isClusterMode && args.primaryPyFile != null) {
@@ -586,7 +584,7 @@ private[spark] class Client(
val creds = new Credentials()
val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath
YarnSparkHadoopUtil.get.obtainTokensForNamenodes(
- nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal")))
+ nns, hadoopConf, creds, sparkConf.get(PRINCIPAL))
val t = creds.getAllTokens.asScala
.filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND)
.head
@@ -606,8 +604,7 @@ private[spark] class Client(
pySparkArchives: Seq[String]): HashMap[String, String] = {
logInfo("Setting up the launch environment for our AM container")
val env = new HashMap[String, String]()
- val extraCp = sparkConf.getOption("spark.driver.extraClassPath")
- populateClasspath(args, yarnConf, sparkConf, env, true, extraCp)
+ populateClasspath(args, yarnConf, sparkConf, env, true, sparkConf.get(DRIVER_CLASS_PATH))
env("SPARK_YARN_MODE") = "true"
env("SPARK_YARN_STAGING_DIR") = stagingDir
env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
@@ -615,11 +612,10 @@ private[spark] class Client(
val remoteFs = FileSystem.get(hadoopConf)
val stagingDirPath = new Path(remoteFs.getHomeDirectory, stagingDir)
val credentialsFile = "credentials-" + UUID.randomUUID().toString
- sparkConf.set(
- "spark.yarn.credentials.file", new Path(stagingDirPath, credentialsFile).toString)
+ sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString)
logInfo(s"Credentials file set to: $credentialsFile")
val renewalInterval = getTokenRenewalInterval(stagingDirPath)
- sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString)
+ sparkConf.set(TOKEN_RENEWAL_INTERVAL, renewalInterval)
}
// Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.*
@@ -713,7 +709,7 @@ private[spark] class Client(
val appId = newAppResponse.getApplicationId
val appStagingDir = getAppStagingDir(appId)
val pySparkArchives =
- if (sparkConf.getBoolean("spark.yarn.isPython", false)) {
+ if (sparkConf.get(IS_PYTHON_APP)) {
findPySparkArchives()
} else {
Nil
@@ -766,36 +762,33 @@ private[spark] class Client(
// Include driver-specific java options if we are launching a driver
if (isClusterMode) {
- val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions")
- .orElse(sys.env.get("SPARK_JAVA_OPTS"))
+ val driverOpts = sparkConf.get(DRIVER_JAVA_OPTIONS).orElse(sys.env.get("SPARK_JAVA_OPTS"))
driverOpts.foreach { opts =>
javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
- val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"),
+ val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH),
sys.props.get("spark.driver.libraryPath")).flatten
if (libraryPaths.nonEmpty) {
prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths)))
}
- if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) {
- logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode")
+ if (sparkConf.get(AM_JAVA_OPTIONS).isDefined) {
+ logWarning(s"${AM_JAVA_OPTIONS.key} will not take effect in cluster mode")
}
} else {
// Validate and include yarn am specific java options in yarn-client mode.
- val amOptsKey = "spark.yarn.am.extraJavaOptions"
- val amOpts = sparkConf.getOption(amOptsKey)
- amOpts.foreach { opts =>
+ sparkConf.get(AM_JAVA_OPTIONS).foreach { opts =>
if (opts.contains("-Dspark")) {
- val msg = s"$amOptsKey is not allowed to set Spark options (was '$opts'). "
+ val msg = s"$${amJavaOptions.key} is not allowed to set Spark options (was '$opts'). "
throw new SparkException(msg)
}
if (opts.contains("-Xmx") || opts.contains("-Xms")) {
- val msg = s"$amOptsKey is not allowed to alter memory settings (was '$opts')."
+ val msg = s"$${amJavaOptions.key} is not allowed to alter memory settings (was '$opts')."
throw new SparkException(msg)
}
javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
- sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths =>
+ sparkConf.get(AM_LIBRARY_PATH).foreach { paths =>
prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths))))
}
}
@@ -883,17 +876,10 @@ private[spark] class Client(
}
def setupCredentials(): Unit = {
- loginFromKeytab = args.principal != null || sparkConf.contains("spark.yarn.principal")
+ loginFromKeytab = args.principal != null || sparkConf.contains(PRINCIPAL.key)
if (loginFromKeytab) {
- principal =
- if (args.principal != null) args.principal else sparkConf.get("spark.yarn.principal")
- keytab = {
- if (args.keytab != null) {
- args.keytab
- } else {
- sparkConf.getOption("spark.yarn.keytab").orNull
- }
- }
+ principal = Option(args.principal).orElse(sparkConf.get(PRINCIPAL)).get
+ keytab = Option(args.keytab).orElse(sparkConf.get(KEYTAB)).orNull
require(keytab != null, "Keytab must be specified when principal is specified.")
logInfo("Attempting to login to the Kerberos" +
@@ -902,8 +888,8 @@ private[spark] class Client(
// Generate a file name that can be used for the keytab file, that does not conflict
// with any user file.
val keytabFileName = f.getName + "-" + UUID.randomUUID().toString
- sparkConf.set("spark.yarn.keytab", keytabFileName)
- sparkConf.set("spark.yarn.principal", principal)
+ sparkConf.set(KEYTAB.key, keytabFileName)
+ sparkConf.set(PRINCIPAL.key, principal)
}
credentials = UserGroupInformation.getCurrentUser.getCredentials
}
@@ -923,7 +909,7 @@ private[spark] class Client(
appId: ApplicationId,
returnOnRunning: Boolean = false,
logApplicationReport: Boolean = true): (YarnApplicationState, FinalApplicationStatus) = {
- val interval = sparkConf.getLong("spark.yarn.report.interval", 1000)
+ val interval = sparkConf.get(REPORT_INTERVAL)
var lastState: YarnApplicationState = null
while (true) {
Thread.sleep(interval)
@@ -1071,14 +1057,14 @@ object Client extends Logging {
val args = new ClientArguments(argStrings, sparkConf)
// to maintain backwards-compatibility
if (!Utils.isDynamicAllocationEnabled(sparkConf)) {
- sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString)
+ sparkConf.setIfMissing(EXECUTOR_INSTANCES, args.numExecutors)
}
new Client(args, sparkConf).run()
}
// Alias for the Spark assembly jar and the user jar
- val SPARK_JAR: String = "__spark__.jar"
- val APP_JAR: String = "__app__.jar"
+ val SPARK_JAR_NAME: String = "__spark__.jar"
+ val APP_JAR_NAME: String = "__app__.jar"
// URI scheme that identifies local resources
val LOCAL_SCHEME = "local"
@@ -1087,20 +1073,8 @@ object Client extends Logging {
val SPARK_STAGING: String = ".sparkStaging"
// Location of any user-defined Spark jars
- val CONF_SPARK_JAR = "spark.yarn.jar"
val ENV_SPARK_JAR = "SPARK_JAR"
- // Internal config to propagate the location of the user's jar to the driver/executors
- val CONF_SPARK_USER_JAR = "spark.yarn.user.jar"
-
- // Internal config to propagate the locations of any extra jars to add to the classpath
- // of the executors
- val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars"
-
- // Comma-separated list of strings to pass through as YARN application tags appearing
- // in YARN ApplicationReports, which can be used for filtering when querying YARN.
- val CONF_SPARK_YARN_APPLICATION_TAGS = "spark.yarn.tags"
-
// Staging directory is private! -> rwx--------
val STAGING_DIR_PERMISSION: FsPermission =
FsPermission.createImmutable(Integer.parseInt("700", 8).toShort)
@@ -1125,23 +1099,23 @@ object Client extends Logging {
* Find the user-defined Spark jar if configured, or return the jar containing this
* class if not.
*
- * This method first looks in the SparkConf object for the CONF_SPARK_JAR key, and in the
+ * This method first looks in the SparkConf object for the spark.yarn.jar key, and in the
* user environment if that is not found (for backwards compatibility).
*/
private def sparkJar(conf: SparkConf): String = {
- if (conf.contains(CONF_SPARK_JAR)) {
- conf.get(CONF_SPARK_JAR)
- } else if (System.getenv(ENV_SPARK_JAR) != null) {
- logWarning(
- s"$ENV_SPARK_JAR detected in the system environment. This variable has been deprecated " +
- s"in favor of the $CONF_SPARK_JAR configuration variable.")
- System.getenv(ENV_SPARK_JAR)
- } else {
- SparkContext.jarOfClass(this.getClass).getOrElse(throw new SparkException("Could not "
- + "find jar containing Spark classes. The jar can be defined using the "
- + "spark.yarn.jar configuration option. If testing Spark, either set that option or "
- + "make sure SPARK_PREPEND_CLASSES is not set."))
- }
+ conf.get(SPARK_JAR).getOrElse(
+ if (System.getenv(ENV_SPARK_JAR) != null) {
+ logWarning(
+ s"$ENV_SPARK_JAR detected in the system environment. This variable has been deprecated " +
+ s"in favor of the ${SPARK_JAR.key} configuration variable.")
+ System.getenv(ENV_SPARK_JAR)
+ } else {
+ SparkContext.jarOfClass(this.getClass).getOrElse(throw new SparkException("Could not "
+ + "find jar containing Spark classes. The jar can be defined using the "
+ + s"${SPARK_JAR.key} configuration option. If testing Spark, either set that option "
+ + "or make sure SPARK_PREPEND_CLASSES is not set."))
+ }
+ )
}
/**
@@ -1240,7 +1214,7 @@ object Client extends Logging {
LOCALIZED_CONF_DIR, env)
}
- if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) {
+ if (sparkConf.get(USER_CLASS_PATH_FIRST)) {
// in order to properly add the app jar when user classpath is first
// we have to do the mainJar separate in order to send the right thing
// into addFileToClasspath
@@ -1248,21 +1222,21 @@ object Client extends Logging {
if (args != null) {
getMainJarUri(Option(args.userJar))
} else {
- getMainJarUri(sparkConf.getOption(CONF_SPARK_USER_JAR))
+ getMainJarUri(sparkConf.get(APP_JAR))
}
- mainJar.foreach(addFileToClasspath(sparkConf, conf, _, APP_JAR, env))
+ mainJar.foreach(addFileToClasspath(sparkConf, conf, _, APP_JAR_NAME, env))
val secondaryJars =
if (args != null) {
- getSecondaryJarUris(Option(args.addJars))
+ getSecondaryJarUris(Option(args.addJars).map(_.split(",").toSeq))
} else {
- getSecondaryJarUris(sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS))
+ getSecondaryJarUris(sparkConf.get(SECONDARY_JARS))
}
secondaryJars.foreach { x =>
addFileToClasspath(sparkConf, conf, x, null, env)
}
}
- addFileToClasspath(sparkConf, conf, new URI(sparkJar(sparkConf)), SPARK_JAR, env)
+ addFileToClasspath(sparkConf, conf, new URI(sparkJar(sparkConf)), SPARK_JAR_NAME, env)
populateHadoopClasspath(conf, env)
sys.env.get(ENV_DIST_CLASSPATH).foreach { cp =>
addClasspathEntry(getClusterPath(sparkConf, cp), env)
@@ -1275,8 +1249,8 @@ object Client extends Logging {
* @param conf Spark configuration.
*/
def getUserClasspath(conf: SparkConf): Array[URI] = {
- val mainUri = getMainJarUri(conf.getOption(CONF_SPARK_USER_JAR))
- val secondaryUris = getSecondaryJarUris(conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS))
+ val mainUri = getMainJarUri(conf.get(APP_JAR))
+ val secondaryUris = getSecondaryJarUris(conf.get(SECONDARY_JARS))
(mainUri ++ secondaryUris).toArray
}
@@ -1284,11 +1258,11 @@ object Client extends Logging {
mainJar.flatMap { path =>
val uri = Utils.resolveURI(path)
if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None
- }.orElse(Some(new URI(APP_JAR)))
+ }.orElse(Some(new URI(APP_JAR_NAME)))
}
- private def getSecondaryJarUris(secondaryJars: Option[String]): Seq[URI] = {
- secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_))
+ private def getSecondaryJarUris(secondaryJars: Option[Seq[String]]): Seq[URI] = {
+ secondaryJars.getOrElse(Nil).map(new URI(_))
}
/**
@@ -1345,8 +1319,8 @@ object Client extends Logging {
* If either config is not available, the input path is returned.
*/
def getClusterPath(conf: SparkConf, path: String): String = {
- val localPath = conf.get("spark.yarn.config.gatewayPath", null)
- val clusterPath = conf.get("spark.yarn.config.replacementPath", null)
+ val localPath = conf.get(GATEWAY_ROOT_PATH)
+ val clusterPath = conf.get(REPLACEMENT_ROOT_PATH)
if (localPath != null && clusterPath != null) {
path.replace(localPath, clusterPath)
} else {
@@ -1405,9 +1379,9 @@ object Client extends Logging {
*/
def isUserClassPathFirst(conf: SparkConf, isDriver: Boolean): Boolean = {
if (isDriver) {
- conf.getBoolean("spark.driver.userClassPathFirst", false)
+ conf.get(DRIVER_USER_CLASS_PATH_FIRST)
} else {
- conf.getBoolean("spark.executor.userClassPathFirst", false)
+ conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index a9f4374357..47b4cc3009 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -21,10 +21,15 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.config._
import org.apache.spark.util.{IntParam, MemoryParam, Utils}
// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
-private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) {
+private[spark] class ClientArguments(
+ args: Array[String],
+ sparkConf: SparkConf) {
+
var addJars: String = null
var files: String = null
var archives: String = null
@@ -37,9 +42,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
var executorMemory = 1024 // MB
var executorCores = 1
var numExecutors = DEFAULT_NUMBER_EXECUTORS
- var amQueue = sparkConf.get("spark.yarn.queue", "default")
- var amMemory: Int = 512 // MB
- var amCores: Int = 1
+ var amQueue = sparkConf.get(QUEUE_NAME)
+ var amMemory: Int = _
+ var amCores: Int = _
var appName: String = "Spark"
var priority = 0
var principal: String = null
@@ -48,11 +53,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB
private var driverCores: Int = 1
- private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead"
- private val amMemKey = "spark.yarn.am.memory"
- private val amMemOverheadKey = "spark.yarn.am.memoryOverhead"
- private val driverCoresKey = "spark.driver.cores"
- private val amCoresKey = "spark.yarn.am.cores"
private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf)
parseArgs(args.toList)
@@ -60,33 +60,33 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
validateArgs()
// Additional memory to allocate to containers
- val amMemoryOverheadConf = if (isClusterMode) driverMemOverheadKey else amMemOverheadKey
- val amMemoryOverhead = sparkConf.getInt(amMemoryOverheadConf,
- math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN))
+ val amMemoryOverheadEntry = if (isClusterMode) DRIVER_MEMORY_OVERHEAD else AM_MEMORY_OVERHEAD
+ val amMemoryOverhead = sparkConf.get(amMemoryOverheadEntry).getOrElse(
+ math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt
- val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead",
- math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN))
+ val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
+ math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt
/** Load any default arguments provided through environment variables and Spark properties. */
private def loadEnvironmentArgs(): Unit = {
// For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://,
// while spark.yarn.dist.{archives/files} should be resolved to file:// (SPARK-2051).
files = Option(files)
- .orElse(sparkConf.getOption("spark.yarn.dist.files").map(p => Utils.resolveURIs(p)))
+ .orElse(sparkConf.get(FILES_TO_DISTRIBUTE).map(p => Utils.resolveURIs(p)))
.orElse(sys.env.get("SPARK_YARN_DIST_FILES"))
.orNull
archives = Option(archives)
- .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p)))
+ .orElse(sparkConf.get(ARCHIVES_TO_DISTRIBUTE).map(p => Utils.resolveURIs(p)))
.orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES"))
.orNull
// If dynamic allocation is enabled, start at the configured initial number of executors.
// Default to minExecutors if no initialExecutors is set.
numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf, numExecutors)
principal = Option(principal)
- .orElse(sparkConf.getOption("spark.yarn.principal"))
+ .orElse(sparkConf.get(PRINCIPAL))
.orNull
keytab = Option(keytab)
- .orElse(sparkConf.getOption("spark.yarn.keytab"))
+ .orElse(sparkConf.get(KEYTAB))
.orNull
}
@@ -103,13 +103,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
|${getUsageMessage()}
""".stripMargin)
}
- if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) {
- throw new SparkException("Executor cores must not be less than " +
- "spark.task.cpus.")
+ if (executorCores < sparkConf.get(CPUS_PER_TASK)) {
+ throw new SparkException(s"Executor cores must not be less than ${CPUS_PER_TASK.key}.")
}
// scalastyle:off println
if (isClusterMode) {
- for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) {
+ for (key <- Seq(AM_MEMORY.key, AM_MEMORY_OVERHEAD.key, AM_CORES.key)) {
if (sparkConf.contains(key)) {
println(s"$key is set but does not apply in cluster mode.")
}
@@ -117,17 +116,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
amMemory = driverMemory
amCores = driverCores
} else {
- for (key <- Seq(driverMemOverheadKey, driverCoresKey)) {
+ for (key <- Seq(DRIVER_MEMORY_OVERHEAD.key, DRIVER_CORES.key)) {
if (sparkConf.contains(key)) {
println(s"$key is set but does not apply in client mode.")
}
}
- sparkConf.getOption(amMemKey)
- .map(Utils.memoryStringToMb)
- .foreach { mem => amMemory = mem }
- sparkConf.getOption(amCoresKey)
- .map(_.toInt)
- .foreach { cores => amCores = cores }
+ amMemory = sparkConf.get(AM_MEMORY).toInt
+ amCores = sparkConf.get(AM_CORES)
}
// scalastyle:on println
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala
index 6474acc3dc..1ae278d76f 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
import org.apache.spark.util.{ThreadUtils, Utils}
private[spark] class ExecutorDelegationTokenUpdater(
@@ -34,7 +35,7 @@ private[spark] class ExecutorDelegationTokenUpdater(
@volatile private var lastCredentialsFileSuffix = 0
- private val credentialsFile = sparkConf.get("spark.yarn.credentials.file")
+ private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH)
private val freshHadoopConf =
SparkHadoopUtil.get.getConfBypassingFSCache(
hadoopConf, new Path(credentialsFile).toUri.getScheme)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 21ac04dc76..9f91d182eb 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -38,11 +38,13 @@ import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.config._
import org.apache.spark.launcher.YarnCommandBuilderUtils
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils
-class ExecutorRunnable(
+private[yarn] class ExecutorRunnable(
container: Container,
conf: Configuration,
sparkConf: SparkConf,
@@ -104,7 +106,7 @@ class ExecutorRunnable(
// If external shuffle service is enabled, register with the Yarn shuffle service already
// started on the NodeManager and, if authentication is enabled, provide it with our secret
// key for fetching shuffle files later
- if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) {
+ if (sparkConf.get(SHUFFLE_SERVICE_ENABLED)) {
val secretString = securityMgr.getSecretKey()
val secretBytes =
if (secretString != null) {
@@ -148,13 +150,13 @@ class ExecutorRunnable(
javaOpts += "-Xmx" + executorMemoryString
// Set extra Java options for the executor, if defined
- sys.props.get("spark.executor.extraJavaOptions").foreach { opts =>
+ sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts =>
javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
sys.env.get("SPARK_JAVA_OPTS").foreach { opts =>
javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
- sys.props.get("spark.executor.extraLibraryPath").foreach { p =>
+ sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p =>
prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p))))
}
@@ -286,8 +288,8 @@ class ExecutorRunnable(
private def prepareEnvironment(container: Container): HashMap[String, String] = {
val env = new HashMap[String, String]()
- val extraCp = sparkConf.getOption("spark.executor.extraClassPath")
- Client.populateClasspath(null, yarnConf, sparkConf, env, false, extraCp)
+ Client.populateClasspath(null, yarnConf, sparkConf, env, false,
+ sparkConf.get(EXECUTOR_CLASS_PATH))
sparkConf.getExecutorEnv.foreach { case (key, value) =>
// This assumes each executor environment variable set here is a path
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
index 2ec189de7c..8772e26f43 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.util.RackResolver
import org.apache.spark.SparkConf
+import org.apache.spark.internal.config._
private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], racks: Array[String])
@@ -84,9 +85,6 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy(
val yarnConf: Configuration,
val resource: Resource) {
- // Number of CPUs per task
- private val CPUS_PER_TASK = sparkConf.getInt("spark.task.cpus", 1)
-
/**
* Calculate each container's node locality and rack locality
* @param numContainer number of containers to calculate
@@ -159,7 +157,7 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy(
*/
private def numExecutorsPending(numTasksPending: Int): Int = {
val coresPerExecutor = resource.getVirtualCores
- (numTasksPending * CPUS_PER_TASK + coresPerExecutor - 1) / coresPerExecutor
+ (numTasksPending * sparkConf.get(CPUS_PER_TASK) + coresPerExecutor - 1) / coresPerExecutor
}
/**
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 11426eb07c..a96cb4957b 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -34,6 +34,7 @@ import org.apache.log4j.{Level, Logger}
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.deploy.yarn.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
@@ -107,21 +108,20 @@ private[yarn] class YarnAllocator(
// Executor memory in MB.
protected val executorMemory = args.executorMemory
// Additional memory overhead.
- protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead",
- math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN))
+ protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
+ math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt
// Number of cores per executor.
protected val executorCores = args.executorCores
// Resource capability requested for each executors
private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores)
private val launcherPool = ThreadUtils.newDaemonCachedThreadPool(
- "ContainerLauncher",
- sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25))
+ "ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS))
// For testing
private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true)
- private val labelExpression = sparkConf.getOption("spark.yarn.executor.nodeLabelExpression")
+ private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION)
// ContainerRequest constructor that can take a node label expression. We grab it through
// reflection because it's only available in later versions of YARN.
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index 98505b93dd..968f635276 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -31,6 +31,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.webapp.util.WebAppUtils
import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.deploy.yarn.config._
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
@@ -117,7 +118,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg
/** Returns the maximum number of attempts to register the AM. */
def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = {
- val sparkMaxAttempts = sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt)
+ val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS).map(_.toInt)
val yarnMaxAttempts = yarnConf.getInt(
YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
val retval: Int = sparkMaxAttempts match {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index aef78fdfd4..ed56d4bd44 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -44,6 +44,8 @@ import org.apache.hadoop.yarn.util.ConverterUtils
import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.config._
import org.apache.spark.launcher.YarnCommandBuilderUtils
import org.apache.spark.util.Utils
@@ -97,10 +99,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
* Get the list of namenodes the user may access.
*/
def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = {
- sparkConf.get("spark.yarn.access.namenodes", "")
- .split(",")
- .map(_.trim())
- .filter(!_.isEmpty)
+ sparkConf.get(NAMENODES_TO_ACCESS)
.map(new Path(_))
.toSet
}
@@ -335,7 +334,7 @@ object YarnSparkHadoopUtil {
// the common cases. Memory overhead tends to grow with container size.
val MEMORY_OVERHEAD_FACTOR = 0.10
- val MEMORY_OVERHEAD_MIN = 384
+ val MEMORY_OVERHEAD_MIN = 384L
val ANY_HOST = "*"
@@ -509,10 +508,9 @@ object YarnSparkHadoopUtil {
conf: SparkConf,
numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = {
if (Utils.isDynamicAllocationEnabled(conf)) {
- val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0)
- val initialNumExecutors =
- conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors)
- val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Int.MaxValue)
+ val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS)
+ val initialNumExecutors = conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS)
+ val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS)
require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors,
s"initial executor number $initialNumExecutors must between min executor number" +
s"$minNumExecutors and max executor number $maxNumExecutors")
@@ -522,7 +520,7 @@ object YarnSparkHadoopUtil {
val targetNumExecutors =
sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(numExecutors)
// System property can override environment variable.
- conf.getInt("spark.executor.instances", targetNumExecutors)
+ conf.get(EXECUTOR_INSTANCES).getOrElse(targetNumExecutors)
}
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
new file mode 100644
index 0000000000..06c1be9bf0
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.util.concurrent.TimeUnit
+
+import org.apache.spark.internal.config.ConfigBuilder
+import org.apache.spark.network.util.ByteUnit
+
+package object config {
+
+ /* Common app configuration. */
+
+ private[spark] val APPLICATION_TAGS = ConfigBuilder("spark.yarn.tags")
+ .doc("Comma-separated list of strings to pass through as YARN application tags appearing " +
+ "in YARN Application Reports, which can be used for filtering when querying YARN.")
+ .stringConf
+ .toSequence
+ .optional
+
+ private[spark] val ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS =
+ ConfigBuilder("spark.yarn.am.attemptFailuresValidityInterval")
+ .doc("Interval after which AM failures will be considered independent and " +
+ "not accumulate towards the attempt count.")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .optional
+
+ private[spark] val MAX_APP_ATTEMPTS = ConfigBuilder("spark.yarn.maxAppAttempts")
+ .doc("Maximum number of AM attempts before failing the app.")
+ .intConf
+ .optional
+
+ private[spark] val USER_CLASS_PATH_FIRST = ConfigBuilder("spark.yarn.user.classpath.first")
+ .doc("Whether to place user jars in front of Spark's classpath.")
+ .booleanConf
+ .withDefault(false)
+
+ private[spark] val GATEWAY_ROOT_PATH = ConfigBuilder("spark.yarn.config.gatewayPath")
+ .doc("Root of configuration paths that is present on gateway nodes, and will be replaced " +
+ "with the corresponding path in cluster machines.")
+ .stringConf
+ .withDefault(null)
+
+ private[spark] val REPLACEMENT_ROOT_PATH = ConfigBuilder("spark.yarn.config.replacementPath")
+ .doc(s"Path to use as a replacement for ${GATEWAY_ROOT_PATH.key} when launching processes " +
+ "in the YARN cluster.")
+ .stringConf
+ .withDefault(null)
+
+ private[spark] val QUEUE_NAME = ConfigBuilder("spark.yarn.queue")
+ .stringConf
+ .withDefault("default")
+
+ private[spark] val HISTORY_SERVER_ADDRESS = ConfigBuilder("spark.yarn.historyServer.address")
+ .stringConf
+ .optional
+
+ /* File distribution. */
+
+ private[spark] val SPARK_JAR = ConfigBuilder("spark.yarn.jar")
+ .doc("Location of the Spark jar to use.")
+ .stringConf
+ .optional
+
+ private[spark] val ARCHIVES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.archives")
+ .stringConf
+ .optional
+
+ private[spark] val FILES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.files")
+ .stringConf
+ .optional
+
+ private[spark] val PRESERVE_STAGING_FILES = ConfigBuilder("spark.yarn.preserve.staging.files")
+ .doc("Whether to preserve temporary files created by the job in HDFS.")
+ .booleanConf
+ .withDefault(false)
+
+ private[spark] val STAGING_FILE_REPLICATION = ConfigBuilder("spark.yarn.submit.file.replication")
+ .doc("Replication factor for files uploaded by Spark to HDFS.")
+ .intConf
+ .optional
+
+ /* Cluster-mode launcher configuration. */
+
+ private[spark] val WAIT_FOR_APP_COMPLETION = ConfigBuilder("spark.yarn.submit.waitAppCompletion")
+ .doc("In cluster mode, whether to wait for the application to finishe before exiting the " +
+ "launcher process.")
+ .booleanConf
+ .withDefault(true)
+
+ private[spark] val REPORT_INTERVAL = ConfigBuilder("spark.yarn.report.interval")
+ .doc("Interval between reports of the current app status in cluster mode.")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .withDefaultString("1s")
+
+ /* Shared Client-mode AM / Driver configuration. */
+
+ private[spark] val AM_MAX_WAIT_TIME = ConfigBuilder("spark.yarn.am.waitTime")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .withDefaultString("100s")
+
+ private[spark] val AM_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.am.nodeLabelExpression")
+ .doc("Node label expression for the AM.")
+ .stringConf
+ .optional
+
+ private[spark] val CONTAINER_LAUNCH_MAX_THREADS =
+ ConfigBuilder("spark.yarn.containerLauncherMaxThreads")
+ .intConf
+ .withDefault(25)
+
+ private[spark] val MAX_EXECUTOR_FAILURES = ConfigBuilder("spark.yarn.max.executor.failures")
+ .intConf
+ .optional
+
+ private[spark] val MAX_REPORTER_THREAD_FAILURES =
+ ConfigBuilder("spark.yarn.scheduler.reporterThread.maxFailures")
+ .intConf
+ .withDefault(5)
+
+ private[spark] val RM_HEARTBEAT_INTERVAL =
+ ConfigBuilder("spark.yarn.scheduler.heartbeat.interval-ms")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .withDefaultString("3s")
+
+ private[spark] val INITIAL_HEARTBEAT_INTERVAL =
+ ConfigBuilder("spark.yarn.scheduler.initial-allocation.interval")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .withDefaultString("200ms")
+
+ private[spark] val SCHEDULER_SERVICES = ConfigBuilder("spark.yarn.services")
+ .doc("A comma-separated list of class names of services to add to the scheduler.")
+ .stringConf
+ .toSequence
+ .withDefault(Nil)
+
+ /* Client-mode AM configuration. */
+
+ private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores")
+ .intConf
+ .withDefault(1)
+
+ private[spark] val AM_JAVA_OPTIONS = ConfigBuilder("spark.yarn.am.extraJavaOptions")
+ .doc("Extra Java options for the client-mode AM.")
+ .stringConf
+ .optional
+
+ private[spark] val AM_LIBRARY_PATH = ConfigBuilder("spark.yarn.am.extraLibraryPath")
+ .doc("Extra native library path for the client-mode AM.")
+ .stringConf
+ .optional
+
+ private[spark] val AM_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.am.memoryOverhead")
+ .bytesConf(ByteUnit.MiB)
+ .optional
+
+ private[spark] val AM_MEMORY = ConfigBuilder("spark.yarn.am.memory")
+ .bytesConf(ByteUnit.MiB)
+ .withDefaultString("512m")
+
+ /* Driver configuration. */
+
+ private[spark] val DRIVER_CORES = ConfigBuilder("spark.driver.cores")
+ .intConf
+ .optional
+
+ private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.driver.memoryOverhead")
+ .bytesConf(ByteUnit.MiB)
+ .optional
+
+ /* Executor configuration. */
+
+ private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.executor.memoryOverhead")
+ .bytesConf(ByteUnit.MiB)
+ .optional
+
+ private[spark] val EXECUTOR_NODE_LABEL_EXPRESSION =
+ ConfigBuilder("spark.yarn.executor.nodeLabelExpression")
+ .doc("Node label expression for executors.")
+ .stringConf
+ .optional
+
+ /* Security configuration. */
+
+ private[spark] val CREDENTIAL_FILE_MAX_COUNT =
+ ConfigBuilder("spark.yarn.credentials.file.retention.count")
+ .intConf
+ .withDefault(5)
+
+ private[spark] val CREDENTIALS_FILE_MAX_RETENTION =
+ ConfigBuilder("spark.yarn.credentials.file.retention.days")
+ .intConf
+ .withDefault(5)
+
+ private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes")
+ .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " +
+ "fs.defaultFS does not need to be listed here.")
+ .stringConf
+ .toSequence
+ .withDefault(Nil)
+
+ private[spark] val TOKEN_RENEWAL_INTERVAL = ConfigBuilder("spark.yarn.token.renewal.interval")
+ .internal
+ .timeConf(TimeUnit.MILLISECONDS)
+ .optional
+
+ /* Private configs. */
+
+ private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file")
+ .internal
+ .stringConf
+ .withDefault(null)
+
+ // Internal config to propagate the location of the user's jar to the driver/executors
+ private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar")
+ .internal
+ .stringConf
+ .optional
+
+ // Internal config to propagate the locations of any extra jars to add to the classpath
+ // of the executors
+ private[spark] val SECONDARY_JARS = ConfigBuilder("spark.yarn.secondary.jars")
+ .internal
+ .stringConf
+ .toSequence
+ .optional
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala
index c064521845..c4757e335b 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.deploy.yarn.config._
import org.apache.spark.util.Utils
/**
@@ -103,20 +104,15 @@ private[spark] class SchedulerExtensionServices extends SchedulerExtensionServic
val attemptId = binding.attemptId
logInfo(s"Starting Yarn extension services with app $appId and attemptId $attemptId")
- serviceOption = sparkContext.getConf.getOption(SchedulerExtensionServices.SPARK_YARN_SERVICES)
- services = serviceOption
- .map { s =>
- s.split(",").map(_.trim()).filter(!_.isEmpty)
- .map { sClass =>
- val instance = Utils.classForName(sClass)
- .newInstance()
- .asInstanceOf[SchedulerExtensionService]
- // bind this service
- instance.start(binding)
- logInfo(s"Service $sClass started")
- instance
- }.toList
- }.getOrElse(Nil)
+ services = sparkContext.conf.get(SCHEDULER_SERVICES).map { sClass =>
+ val instance = Utils.classForName(sClass)
+ .newInstance()
+ .asInstanceOf[SchedulerExtensionService]
+ // bind this service
+ instance.start(binding)
+ logInfo(s"Service $sClass started")
+ instance
+ }.toList
}
/**
@@ -144,11 +140,3 @@ private[spark] class SchedulerExtensionServices extends SchedulerExtensionServic
| services=$services,
| started=$started)""".stripMargin
}
-
-private[spark] object SchedulerExtensionServices {
-
- /**
- * A list of comma separated services to instantiate in the scheduler
- */
- val SPARK_YARN_SERVICES = "spark.yarn.services"
-}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 19065373c6..b57c179d89 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -41,6 +41,7 @@ import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfterAll, Matchers}
import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.yarn.config._
import org.apache.spark.util.{ResetSystemProperties, Utils}
class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
@@ -103,8 +104,9 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
test("Local jar URIs") {
val conf = new Configuration()
- val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK)
- .set("spark.yarn.user.classpath.first", "true")
+ val sparkConf = new SparkConf()
+ .set(SPARK_JAR, SPARK)
+ .set(USER_CLASS_PATH_FIRST, true)
val env = new MutableHashMap[String, String]()
val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf)
@@ -129,13 +131,13 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
}
cp should contain(pwdVar)
cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}")
- cp should not contain (Client.SPARK_JAR)
- cp should not contain (Client.APP_JAR)
+ cp should not contain (Client.SPARK_JAR_NAME)
+ cp should not contain (Client.APP_JAR_NAME)
}
test("Jar path propagation through SparkConf") {
val conf = new Configuration()
- val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK)
+ val sparkConf = new SparkConf().set(SPARK_JAR, SPARK)
val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf)
val client = spy(new Client(args, conf, sparkConf))
@@ -145,7 +147,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
val tempDir = Utils.createTempDir()
try {
client.prepareLocalResources(tempDir.getAbsolutePath(), Nil)
- sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER))
+ sparkConf.get(APP_JAR) should be (Some(USER))
// The non-local path should be propagated by name only, since it will end up in the app's
// staging dir.
@@ -160,7 +162,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
})
.mkString(",")
- sparkConf.getOption(Client.CONF_SPARK_YARN_SECONDARY_JARS) should be (Some(expected))
+ sparkConf.get(SECONDARY_JARS) should be (Some(expected.split(",").toSeq))
} finally {
Utils.deleteRecursively(tempDir)
}
@@ -169,9 +171,9 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
test("Cluster path translation") {
val conf = new Configuration()
val sparkConf = new SparkConf()
- .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar")
- .set("spark.yarn.config.gatewayPath", "/localPath")
- .set("spark.yarn.config.replacementPath", "/remotePath")
+ .set(SPARK_JAR.key, "local:/localPath/spark.jar")
+ .set(GATEWAY_ROOT_PATH, "/localPath")
+ .set(REPLACEMENT_ROOT_PATH, "/remotePath")
Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath")
Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be (
@@ -191,8 +193,8 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
// Spaces between non-comma strings should be preserved as single tags. Empty strings may or
// may not be removed depending on the version of Hadoop being used.
val sparkConf = new SparkConf()
- .set(Client.CONF_SPARK_YARN_APPLICATION_TAGS, ",tag1, dup,tag2 , ,multi word , dup")
- .set("spark.yarn.maxAppAttempts", "42")
+ .set(APPLICATION_TAGS.key, ",tag1, dup,tag2 , ,multi word , dup")
+ .set(MAX_APP_ATTEMPTS, 42)
val args = new ClientArguments(Array(
"--name", "foo-test-app",
"--queue", "staging-queue"), sparkConf)
diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala
index b4d1b0a3d2..338fbe2ef4 100644
--- a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster
import org.scalatest.BeforeAndAfter
import org.apache.spark.{LocalSparkContext, Logging, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.deploy.yarn.config._
/**
* Test the integration with [[SchedulerExtensionServices]]
@@ -36,8 +37,7 @@ class ExtensionServiceIntegrationSuite extends SparkFunSuite
*/
before {
val sparkConf = new SparkConf()
- sparkConf.set(SchedulerExtensionServices.SPARK_YARN_SERVICES,
- classOf[SimpleExtensionService].getName())
+ sparkConf.set(SCHEDULER_SERVICES, Seq(classOf[SimpleExtensionService].getName()))
sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite")
sc = new SparkContext(sparkConf)
}