aboutsummaryrefslogtreecommitdiff
path: root/resource-managers/yarn/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'resource-managers/yarn/src/test')
-rw-r--r--resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider1
-rw-r--r--resource-managers/yarn/src/test/resources/log4j.properties31
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala241
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala204
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala462
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala153
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala344
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala493
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala112
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala213
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala150
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProviderSuite.scala71
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala36
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala70
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala372
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala37
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala72
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala34
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala48
-rw-r--r--resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala42
20 files changed, 3186 insertions, 0 deletions
diff --git a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider
new file mode 100644
index 0000000000..d0ef5efa36
--- /dev/null
+++ b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider
@@ -0,0 +1 @@
+org.apache.spark.deploy.yarn.security.TestCredentialProvider
diff --git a/resource-managers/yarn/src/test/resources/log4j.properties b/resource-managers/yarn/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..d13454d5ae
--- /dev/null
+++ b/resource-managers/yarn/src/test/resources/log4j.properties
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootCategory=DEBUG, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from a few verbose libraries.
+log4j.logger.com.sun.jersey=WARN
+log4j.logger.org.apache.hadoop=WARN
+log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.mortbay=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
new file mode 100644
index 0000000000..9c3b18e4ec
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.nio.charset.StandardCharsets
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import com.google.common.io.Files
+import org.apache.commons.lang3.SerializationUtils
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.server.MiniYARNCluster
+import org.scalatest.{BeforeAndAfterAll, Matchers}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark._
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.launcher._
+import org.apache.spark.util.Utils
+
+abstract class BaseYarnClusterSuite
+ extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging {
+
+ // log4j configuration for the YARN containers, so that their output is collected
+ // by YARN instead of trying to overwrite unit-tests.log.
+ protected val LOG4J_CONF = """
+ |log4j.rootCategory=DEBUG, console
+ |log4j.appender.console=org.apache.log4j.ConsoleAppender
+ |log4j.appender.console.target=System.err
+ |log4j.appender.console.layout=org.apache.log4j.PatternLayout
+ |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+ |log4j.logger.org.apache.hadoop=WARN
+ |log4j.logger.org.eclipse.jetty=WARN
+ |log4j.logger.org.mortbay=WARN
+ |log4j.logger.org.spark_project.jetty=WARN
+ """.stripMargin
+
+ private var yarnCluster: MiniYARNCluster = _
+ protected var tempDir: File = _
+ private var fakeSparkJar: File = _
+ protected var hadoopConfDir: File = _
+ private var logConfDir: File = _
+
+ var oldSystemProperties: Properties = null
+
+ def newYarnConfig(): YarnConfiguration
+
+ override def beforeAll() {
+ super.beforeAll()
+ oldSystemProperties = SerializationUtils.clone(System.getProperties)
+
+ tempDir = Utils.createTempDir()
+ logConfDir = new File(tempDir, "log4j")
+ logConfDir.mkdir()
+ System.setProperty("SPARK_YARN_MODE", "true")
+
+ val logConfFile = new File(logConfDir, "log4j.properties")
+ Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8)
+
+ // Disable the disk utilization check to avoid the test hanging when people's disks are
+ // getting full.
+ val yarnConf = newYarnConfig()
+ yarnConf.set("yarn.nodemanager.disk-health-checker.max-disk-utilization-per-disk-percentage",
+ "100.0")
+
+ yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
+ yarnCluster.init(yarnConf)
+ yarnCluster.start()
+
+ // There's a race in MiniYARNCluster in which start() may return before the RM has updated
+ // its address in the configuration. You can see this in the logs by noticing that when
+ // MiniYARNCluster prints the address, it still has port "0" assigned, although later the
+ // test works sometimes:
+ //
+ // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0
+ //
+ // That log message prints the contents of the RM_ADDRESS config variable. If you check it
+ // later on, it looks something like this:
+ //
+ // INFO YarnClusterSuite: RM address in configuration is blah:42631
+ //
+ // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't
+ // done so in a timely manner (defined to be 10 seconds).
+ val config = yarnCluster.getConfig()
+ val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10)
+ while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") {
+ if (System.currentTimeMillis() > deadline) {
+ throw new IllegalStateException("Timed out waiting for RM to come up.")
+ }
+ logDebug("RM address still not set in configuration, waiting...")
+ TimeUnit.MILLISECONDS.sleep(100)
+ }
+
+ logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
+
+ fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
+ hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
+ assert(hadoopConfDir.mkdir())
+ File.createTempFile("token", ".txt", hadoopConfDir)
+ }
+
+ override def afterAll() {
+ try {
+ yarnCluster.stop()
+ } finally {
+ System.setProperties(oldSystemProperties)
+ super.afterAll()
+ }
+ }
+
+ protected def runSpark(
+ clientMode: Boolean,
+ klass: String,
+ appArgs: Seq[String] = Nil,
+ sparkArgs: Seq[(String, String)] = Nil,
+ extraClassPath: Seq[String] = Nil,
+ extraJars: Seq[String] = Nil,
+ extraConf: Map[String, String] = Map(),
+ extraEnv: Map[String, String] = Map()): SparkAppHandle.State = {
+ val deployMode = if (clientMode) "client" else "cluster"
+ val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf)
+ val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv
+
+ val launcher = new SparkLauncher(env.asJava)
+ if (klass.endsWith(".py")) {
+ launcher.setAppResource(klass)
+ } else {
+ launcher.setMainClass(klass)
+ launcher.setAppResource(fakeSparkJar.getAbsolutePath())
+ }
+ launcher.setSparkHome(sys.props("spark.test.home"))
+ .setMaster("yarn")
+ .setDeployMode(deployMode)
+ .setConf("spark.executor.instances", "1")
+ .setPropertiesFile(propsFile)
+ .addAppArgs(appArgs.toArray: _*)
+
+ sparkArgs.foreach { case (name, value) =>
+ if (value != null) {
+ launcher.addSparkArg(name, value)
+ } else {
+ launcher.addSparkArg(name)
+ }
+ }
+ extraJars.foreach(launcher.addJar)
+
+ val handle = launcher.startApplication()
+ try {
+ eventually(timeout(2 minutes), interval(1 second)) {
+ assert(handle.getState().isFinal())
+ }
+ } finally {
+ handle.kill()
+ }
+
+ handle.getState()
+ }
+
+ /**
+ * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
+ * any sort of error when the job process finishes successfully, but the job itself fails. So
+ * the tests enforce that something is written to a file after everything is ok to indicate
+ * that the job succeeded.
+ */
+ protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = {
+ checkResult(finalState, result, "success")
+ }
+
+ protected def checkResult(
+ finalState: SparkAppHandle.State,
+ result: File,
+ expected: String): Unit = {
+ finalState should be (SparkAppHandle.State.FINISHED)
+ val resultString = Files.toString(result, StandardCharsets.UTF_8)
+ resultString should be (expected)
+ }
+
+ protected def mainClassName(klass: Class[_]): String = {
+ klass.getName().stripSuffix("$")
+ }
+
+ protected def createConfFile(
+ extraClassPath: Seq[String] = Nil,
+ extraConf: Map[String, String] = Map()): String = {
+ val props = new Properties()
+ props.put(SPARK_JARS.key, "local:" + fakeSparkJar.getAbsolutePath())
+
+ val testClasspath = new TestClasspathBuilder()
+ .buildClassPath(
+ logConfDir.getAbsolutePath() +
+ File.pathSeparator +
+ extraClassPath.mkString(File.pathSeparator))
+ .asScala
+ .mkString(File.pathSeparator)
+
+ props.put("spark.driver.extraClassPath", testClasspath)
+ props.put("spark.executor.extraClassPath", testClasspath)
+
+ // SPARK-4267: make sure java options are propagated correctly.
+ props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
+ props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
+
+ yarnCluster.getConfig().asScala.foreach { e =>
+ props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
+ }
+ sys.props.foreach { case (k, v) =>
+ if (k.startsWith("spark.")) {
+ props.setProperty(k, v)
+ }
+ }
+ extraConf.foreach { case (k, v) => props.setProperty(k, v) }
+
+ val propsFile = File.createTempFile("spark", ".properties", tempDir)
+ val writer = new OutputStreamWriter(new FileOutputStream(propsFile), StandardCharsets.UTF_8)
+ props.store(writer, "Spark properties.")
+ writer.close()
+ propsFile.getAbsolutePath()
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
new file mode 100644
index 0000000000..b696e080ce
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.util.ConverterUtils
+import org.mockito.Mockito.when
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.yarn.config._
+
+class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar {
+
+ class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
+ override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ LocalResourceVisibility.PRIVATE
+ }
+ }
+
+ test("test getFileStatus empty") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath() === null)
+ }
+
+ test("test getFileStatus cached") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath().toString() === "/tmp/testing")
+ }
+
+ test("test addResource") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 0)
+ assert(resource.getSize() === 0)
+ assert(resource.getType() === LocalResourceType.FILE)
+
+ val sparkConf = new SparkConf(false)
+ distMgr.updateConfiguration(sparkConf)
+ assert(sparkConf.get(CACHED_FILES) === Seq("file:/foo.invalid.com:8080/tmp/testing#link"))
+ assert(sparkConf.get(CACHED_FILES_TIMESTAMPS) === Seq(0L))
+ assert(sparkConf.get(CACHED_FILES_SIZES) === Seq(0L))
+ assert(sparkConf.get(CACHED_FILES_VISIBILITIES) === Seq(LocalResourceVisibility.PRIVATE.name()))
+ assert(sparkConf.get(CACHED_FILES_TYPES) === Seq(LocalResourceType.FILE.name()))
+
+ // add another one and verify both there and order correct
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing2"))
+ val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
+ when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
+ distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+ statCache, false)
+ val resource2 = localResources("link2")
+ assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2)
+ assert(resource2.getTimestamp() === 10)
+ assert(resource2.getSize() === 20)
+ assert(resource2.getType() === LocalResourceType.FILE)
+
+ val sparkConf2 = new SparkConf(false)
+ distMgr.updateConfiguration(sparkConf2)
+
+ val files = sparkConf2.get(CACHED_FILES)
+ val sizes = sparkConf2.get(CACHED_FILES_SIZES)
+ val timestamps = sparkConf2.get(CACHED_FILES_TIMESTAMPS)
+ val visibilities = sparkConf2.get(CACHED_FILES_VISIBILITIES)
+
+ assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(timestamps(0) === 0)
+ assert(sizes(0) === 0)
+ assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name())
+
+ assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2")
+ assert(timestamps(1) === 10)
+ assert(sizes(1) === 20)
+ assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name())
+ }
+
+ test("test addResource link null") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ intercept[Exception] {
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+ statCache, false)
+ }
+ assert(localResources.get("link") === None)
+ assert(localResources.size === 0)
+ }
+
+ test("test addResource appmaster only") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, true)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val sparkConf = new SparkConf(false)
+ distMgr.updateConfiguration(sparkConf)
+ assert(sparkConf.get(CACHED_FILES) === Nil)
+ assert(sparkConf.get(CACHED_FILES_TIMESTAMPS) === Nil)
+ assert(sparkConf.get(CACHED_FILES_SIZES) === Nil)
+ assert(sparkConf.get(CACHED_FILES_VISIBILITIES) === Nil)
+ assert(sparkConf.get(CACHED_FILES_TYPES) === Nil)
+ }
+
+ test("test addResource archive") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val sparkConf = new SparkConf(false)
+ distMgr.updateConfiguration(sparkConf)
+ assert(sparkConf.get(CACHED_FILES) === Seq("file:/foo.invalid.com:8080/tmp/testing#link"))
+ assert(sparkConf.get(CACHED_FILES_SIZES) === Seq(20L))
+ assert(sparkConf.get(CACHED_FILES_TIMESTAMPS) === Seq(10L))
+ assert(sparkConf.get(CACHED_FILES_VISIBILITIES) === Seq(LocalResourceVisibility.PRIVATE.name()))
+ assert(sparkConf.get(CACHED_FILES_TYPES) === Seq(LocalResourceType.ARCHIVE.name()))
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
new file mode 100644
index 0000000000..7deaf0af94
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -0,0 +1,462 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.net.URI
+import java.util.Properties
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{HashMap => MutableHashMap}
+import scala.reflect.ClassTag
+import scala.util.Try
+
+import org.apache.commons.lang3.SerializationUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.MRJobConfig
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.YarnClientApplication
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.Records
+import org.mockito.Matchers.{eq => meq, _}
+import org.mockito.Mockito._
+import org.scalatest.{BeforeAndAfterAll, Matchers}
+
+import org.apache.spark.{SparkConf, SparkFunSuite, TestUtils}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils}
+
+class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
+ with ResetSystemProperties {
+
+ import Client._
+
+ var oldSystemProperties: Properties = null
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ oldSystemProperties = SerializationUtils.clone(System.getProperties)
+ System.setProperty("SPARK_YARN_MODE", "true")
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ System.setProperties(oldSystemProperties)
+ oldSystemProperties = null
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ test("default Yarn application classpath") {
+ getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP))
+ }
+
+ test("default MR application classpath") {
+ getDefaultMRApplicationClasspath should be(Some(Fixtures.knownDefMRAppCP))
+ }
+
+ test("resultant classpath for an application that defines a classpath for YARN") {
+ withAppConf(Fixtures.mapYARNAppConf) { conf =>
+ val env = newEnv
+ populateHadoopClasspath(conf, env)
+ classpath(env) should be(
+ flatten(Fixtures.knownYARNAppCP, getDefaultMRApplicationClasspath))
+ }
+ }
+
+ test("resultant classpath for an application that defines a classpath for MR") {
+ withAppConf(Fixtures.mapMRAppConf) { conf =>
+ val env = newEnv
+ populateHadoopClasspath(conf, env)
+ classpath(env) should be(
+ flatten(getDefaultYarnApplicationClasspath, Fixtures.knownMRAppCP))
+ }
+ }
+
+ test("resultant classpath for an application that defines both classpaths, YARN and MR") {
+ withAppConf(Fixtures.mapAppConf) { conf =>
+ val env = newEnv
+ populateHadoopClasspath(conf, env)
+ classpath(env) should be(flatten(Fixtures.knownYARNAppCP, Fixtures.knownMRAppCP))
+ }
+ }
+
+ private val SPARK = "local:/sparkJar"
+ private val USER = "local:/userJar"
+ private val ADDED = "local:/addJar1,local:/addJar2,/addJar3"
+
+ private val PWD =
+ if (classOf[Environment].getMethods().exists(_.getName == "$$")) {
+ "{{PWD}}"
+ } else if (Utils.isWindows) {
+ "%PWD%"
+ } else {
+ Environment.PWD.$()
+ }
+
+ test("Local jar URIs") {
+ val conf = new Configuration()
+ val sparkConf = new SparkConf()
+ .set(SPARK_JARS, Seq(SPARK))
+ .set(USER_CLASS_PATH_FIRST, true)
+ .set("spark.yarn.dist.jars", ADDED)
+ val env = new MutableHashMap[String, String]()
+ val args = new ClientArguments(Array("--jar", USER))
+
+ populateClasspath(args, conf, sparkConf, env)
+
+ val cp = env("CLASSPATH").split(":|;|<CPS>")
+ s"$SPARK,$USER,$ADDED".split(",").foreach({ entry =>
+ val uri = new URI(entry)
+ if (LOCAL_SCHEME.equals(uri.getScheme())) {
+ cp should contain (uri.getPath())
+ } else {
+ cp should not contain (uri.getPath())
+ }
+ })
+ cp should contain(PWD)
+ cp should contain (s"$PWD${Path.SEPARATOR}${LOCALIZED_CONF_DIR}")
+ cp should not contain (APP_JAR)
+ }
+
+ test("Jar path propagation through SparkConf") {
+ val conf = new Configuration()
+ val sparkConf = new SparkConf()
+ .set(SPARK_JARS, Seq(SPARK))
+ .set("spark.yarn.dist.jars", ADDED)
+ val client = createClient(sparkConf, args = Array("--jar", USER))
+ doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]),
+ any(classOf[Path]), anyShort(), anyBoolean(), any())
+
+ val tempDir = Utils.createTempDir()
+ try {
+ // Because we mocked "copyFileToRemote" above to avoid having to create fake local files,
+ // we need to create a fake config archive in the temp dir to avoid having
+ // prepareLocalResources throw an exception.
+ new FileOutputStream(new File(tempDir, LOCALIZED_CONF_ARCHIVE)).close()
+
+ client.prepareLocalResources(new Path(tempDir.getAbsolutePath()), Nil)
+ sparkConf.get(APP_JAR) should be (Some(USER))
+
+ // The non-local path should be propagated by name only, since it will end up in the app's
+ // staging dir.
+ val expected = ADDED.split(",")
+ .map(p => {
+ val uri = new URI(p)
+ if (LOCAL_SCHEME == uri.getScheme()) {
+ p
+ } else {
+ Option(uri.getFragment()).getOrElse(new File(p).getName())
+ }
+ })
+ .mkString(",")
+
+ sparkConf.get(SECONDARY_JARS) should be (Some(expected.split(",").toSeq))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ test("Cluster path translation") {
+ val conf = new Configuration()
+ val sparkConf = new SparkConf()
+ .set(SPARK_JARS, Seq("local:/localPath/spark.jar"))
+ .set(GATEWAY_ROOT_PATH, "/localPath")
+ .set(REPLACEMENT_ROOT_PATH, "/remotePath")
+
+ getClusterPath(sparkConf, "/localPath") should be ("/remotePath")
+ getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be (
+ "/remotePath/1:/remotePath/2")
+
+ val env = new MutableHashMap[String, String]()
+ populateClasspath(null, conf, sparkConf, env, extraClassPath = Some("/localPath/my1.jar"))
+ val cp = classpath(env)
+ cp should contain ("/remotePath/spark.jar")
+ cp should contain ("/remotePath/my1.jar")
+ }
+
+ test("configuration and args propagate through createApplicationSubmissionContext") {
+ val conf = new Configuration()
+ // When parsing tags, duplicates and leading/trailing whitespace should be removed.
+ // Spaces between non-comma strings should be preserved as single tags. Empty strings may or
+ // may not be removed depending on the version of Hadoop being used.
+ val sparkConf = new SparkConf()
+ .set(APPLICATION_TAGS.key, ",tag1, dup,tag2 , ,multi word , dup")
+ .set(MAX_APP_ATTEMPTS, 42)
+ .set("spark.app.name", "foo-test-app")
+ .set(QUEUE_NAME, "staging-queue")
+ val args = new ClientArguments(Array())
+
+ val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
+ val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse])
+ val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext])
+
+ val client = new Client(args, conf, sparkConf)
+ client.createApplicationSubmissionContext(
+ new YarnClientApplication(getNewApplicationResponse, appContext),
+ containerLaunchContext)
+
+ appContext.getApplicationName should be ("foo-test-app")
+ appContext.getQueue should be ("staging-queue")
+ appContext.getAMContainerSpec should be (containerLaunchContext)
+ appContext.getApplicationType should be ("SPARK")
+ appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method =>
+ val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]]
+ tags should contain allOf ("tag1", "dup", "tag2", "multi word")
+ tags.asScala.count(_.nonEmpty) should be (4)
+ }
+ appContext.getMaxAppAttempts should be (42)
+ }
+
+ test("spark.yarn.jars with multiple paths and globs") {
+ val libs = Utils.createTempDir()
+ val single = Utils.createTempDir()
+ val jar1 = TestUtils.createJarWithFiles(Map(), libs)
+ val jar2 = TestUtils.createJarWithFiles(Map(), libs)
+ val jar3 = TestUtils.createJarWithFiles(Map(), single)
+ val jar4 = TestUtils.createJarWithFiles(Map(), single)
+
+ val jarsConf = Seq(
+ s"${libs.getAbsolutePath()}/*",
+ jar3.getPath(),
+ s"local:${jar4.getPath()}",
+ s"local:${single.getAbsolutePath()}/*")
+
+ val sparkConf = new SparkConf().set(SPARK_JARS, jarsConf)
+ val client = createClient(sparkConf)
+
+ val tempDir = Utils.createTempDir()
+ client.prepareLocalResources(new Path(tempDir.getAbsolutePath()), Nil)
+
+ assert(sparkConf.get(SPARK_JARS) ===
+ Some(Seq(s"local:${jar4.getPath()}", s"local:${single.getAbsolutePath()}/*")))
+
+ verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar1.toURI())), anyShort(),
+ anyBoolean(), any())
+ verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar2.toURI())), anyShort(),
+ anyBoolean(), any())
+ verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar3.toURI())), anyShort(),
+ anyBoolean(), any())
+
+ val cp = classpath(client)
+ cp should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*"))
+ cp should not contain (jar3.getPath())
+ cp should contain (jar4.getPath())
+ cp should contain (buildPath(single.getAbsolutePath(), "*"))
+ }
+
+ test("distribute jars archive") {
+ val temp = Utils.createTempDir()
+ val archive = TestUtils.createJarWithFiles(Map(), temp)
+
+ val sparkConf = new SparkConf().set(SPARK_ARCHIVE, archive.getPath())
+ val client = createClient(sparkConf)
+ client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil)
+
+ verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(archive.toURI())), anyShort(),
+ anyBoolean(), any())
+ classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*"))
+
+ sparkConf.set(SPARK_ARCHIVE, LOCAL_SCHEME + ":" + archive.getPath())
+ intercept[IllegalArgumentException] {
+ client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil)
+ }
+ }
+
+ test("distribute archive multiple times") {
+ val libs = Utils.createTempDir()
+ // Create jars dir and RELEASE file to avoid IllegalStateException.
+ val jarsDir = new File(libs, "jars")
+ assert(jarsDir.mkdir())
+ new FileOutputStream(new File(libs, "RELEASE")).close()
+
+ val userLib1 = Utils.createTempDir()
+ val testJar = TestUtils.createJarWithFiles(Map(), userLib1)
+
+ // Case 1: FILES_TO_DISTRIBUTE and ARCHIVES_TO_DISTRIBUTE can't have duplicate files
+ val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+ .set(FILES_TO_DISTRIBUTE, Seq(testJar.getPath))
+ .set(ARCHIVES_TO_DISTRIBUTE, Seq(testJar.getPath))
+
+ val client = createClient(sparkConf)
+ val tempDir = Utils.createTempDir()
+ intercept[IllegalArgumentException] {
+ client.prepareLocalResources(new Path(tempDir.getAbsolutePath()), Nil)
+ }
+
+ // Case 2: FILES_TO_DISTRIBUTE can't have duplicate files.
+ val sparkConfFiles = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+ .set(FILES_TO_DISTRIBUTE, Seq(testJar.getPath, testJar.getPath))
+
+ val clientFiles = createClient(sparkConfFiles)
+ val tempDirForFiles = Utils.createTempDir()
+ intercept[IllegalArgumentException] {
+ clientFiles.prepareLocalResources(new Path(tempDirForFiles.getAbsolutePath()), Nil)
+ }
+
+ // Case 3: ARCHIVES_TO_DISTRIBUTE can't have duplicate files.
+ val sparkConfArchives = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+ .set(ARCHIVES_TO_DISTRIBUTE, Seq(testJar.getPath, testJar.getPath))
+
+ val clientArchives = createClient(sparkConfArchives)
+ val tempDirForArchives = Utils.createTempDir()
+ intercept[IllegalArgumentException] {
+ clientArchives.prepareLocalResources(new Path(tempDirForArchives.getAbsolutePath()), Nil)
+ }
+
+ // Case 4: FILES_TO_DISTRIBUTE can have unique file.
+ val sparkConfFilesUniq = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+ .set(FILES_TO_DISTRIBUTE, Seq(testJar.getPath))
+
+ val clientFilesUniq = createClient(sparkConfFilesUniq)
+ val tempDirForFilesUniq = Utils.createTempDir()
+ clientFilesUniq.prepareLocalResources(new Path(tempDirForFilesUniq.getAbsolutePath()), Nil)
+
+ // Case 5: ARCHIVES_TO_DISTRIBUTE can have unique file.
+ val sparkConfArchivesUniq = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+ .set(ARCHIVES_TO_DISTRIBUTE, Seq(testJar.getPath))
+
+ val clientArchivesUniq = createClient(sparkConfArchivesUniq)
+ val tempDirArchivesUniq = Utils.createTempDir()
+ clientArchivesUniq.prepareLocalResources(new Path(tempDirArchivesUniq.getAbsolutePath()), Nil)
+
+ }
+
+ test("distribute local spark jars") {
+ val temp = Utils.createTempDir()
+ val jarsDir = new File(temp, "jars")
+ assert(jarsDir.mkdir())
+ val jar = TestUtils.createJarWithFiles(Map(), jarsDir)
+ new FileOutputStream(new File(temp, "RELEASE")).close()
+
+ val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> temp.getAbsolutePath()))
+ val client = createClient(sparkConf)
+ client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil)
+ classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*"))
+ }
+
+ test("ignore same name jars") {
+ val libs = Utils.createTempDir()
+ val jarsDir = new File(libs, "jars")
+ assert(jarsDir.mkdir())
+ new FileOutputStream(new File(libs, "RELEASE")).close()
+ val userLib1 = Utils.createTempDir()
+ val userLib2 = Utils.createTempDir()
+
+ val jar1 = TestUtils.createJarWithFiles(Map(), jarsDir)
+ val jar2 = TestUtils.createJarWithFiles(Map(), userLib1)
+ // Copy jar2 to jar3 with same name
+ val jar3 = {
+ val target = new File(userLib2, new File(jar2.toURI).getName)
+ val input = new FileInputStream(jar2.getPath)
+ val output = new FileOutputStream(target)
+ Utils.copyStream(input, output, closeStreams = true)
+ target.toURI.toURL
+ }
+
+ val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> libs.getAbsolutePath))
+ .set(JARS_TO_DISTRIBUTE, Seq(jar2.getPath, jar3.getPath))
+
+ val client = createClient(sparkConf)
+ val tempDir = Utils.createTempDir()
+ client.prepareLocalResources(new Path(tempDir.getAbsolutePath()), Nil)
+
+ // Only jar2 will be added to SECONDARY_JARS, jar3 which has the same name with jar2 will be
+ // ignored.
+ sparkConf.get(SECONDARY_JARS) should be (Some(Seq(new File(jar2.toURI).getName)))
+ }
+
+ object Fixtures {
+
+ val knownDefYarnAppCP: Seq[String] =
+ getFieldValue[Array[String], Seq[String]](classOf[YarnConfiguration],
+ "DEFAULT_YARN_APPLICATION_CLASSPATH",
+ Seq[String]())(a => a.toSeq)
+
+
+ val knownDefMRAppCP: Seq[String] =
+ getFieldValue2[String, Array[String], Seq[String]](
+ classOf[MRJobConfig],
+ "DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH",
+ Seq[String]())(a => a.split(","))(a => a.toSeq)
+
+ val knownYARNAppCP = Some(Seq("/known/yarn/path"))
+
+ val knownMRAppCP = Some(Seq("/known/mr/path"))
+
+ val mapMRAppConf =
+ Map("mapreduce.application.classpath" -> knownMRAppCP.map(_.mkString(":")).get)
+
+ val mapYARNAppConf =
+ Map(YarnConfiguration.YARN_APPLICATION_CLASSPATH -> knownYARNAppCP.map(_.mkString(":")).get)
+
+ val mapAppConf = mapYARNAppConf ++ mapMRAppConf
+ }
+
+ def withAppConf(m: Map[String, String] = Map())(testCode: (Configuration) => Any) {
+ val conf = new Configuration
+ m.foreach { case (k, v) => conf.set(k, v, "ClientSpec") }
+ testCode(conf)
+ }
+
+ def newEnv: MutableHashMap[String, String] = MutableHashMap[String, String]()
+
+ def classpath(env: MutableHashMap[String, String]): Array[String] =
+ env(Environment.CLASSPATH.name).split(":|;|<CPS>")
+
+ def flatten(a: Option[Seq[String]], b: Option[Seq[String]]): Array[String] =
+ (a ++ b).flatten.toArray
+
+ def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = {
+ Try(clazz.getField(field))
+ .map(_.get(null).asInstanceOf[A])
+ .toOption
+ .map(mapTo)
+ .getOrElse(defaults)
+ }
+
+ def getFieldValue2[A: ClassTag, A1: ClassTag, B](
+ clazz: Class[_],
+ field: String,
+ defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = {
+ Try(clazz.getField(field)).map(_.get(null)).map {
+ case v: A => mapTo(v)
+ case v1: A1 => mapTo1(v1)
+ case _ => defaults
+ }.toOption.getOrElse(defaults)
+ }
+
+ private def createClient(
+ sparkConf: SparkConf,
+ conf: Configuration = new Configuration(),
+ args: Array[String] = Array()): Client = {
+ val clientArgs = new ClientArguments(args)
+ spy(new Client(clientArgs, conf, sparkConf))
+ }
+
+ private def classpath(client: Client): Array[String] = {
+ val env = new MutableHashMap[String, String]()
+ populateClasspath(null, client.hadoopConf, client.sparkConf, env)
+ classpath(env)
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala
new file mode 100644
index 0000000000..afb4b691b5
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.SparkFunSuite
+
+class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
+
+ private val yarnAllocatorSuite = new YarnAllocatorSuite
+ import yarnAllocatorSuite._
+
+ def createContainerRequest(nodes: Array[String]): ContainerRequest =
+ new ContainerRequest(containerResource, nodes, null, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)
+
+ override def beforeEach() {
+ yarnAllocatorSuite.beforeEach()
+ }
+
+ override def afterEach() {
+ yarnAllocatorSuite.afterEach()
+ }
+
+ test("allocate locality preferred containers with enough resource and no matched existed " +
+ "containers") {
+ // 1. All the locations of current containers cannot satisfy the new requirements
+ // 2. Current requested container number can fully satisfy the pending tasks.
+
+ val handler = createAllocator(2)
+ handler.updateResourceRequests()
+ handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2")))
+
+ val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+ 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10),
+ handler.allocatedHostToContainersMap, Seq.empty)
+
+ assert(localities.map(_.nodes) === Array(
+ Array("host3", "host4", "host5"),
+ Array("host3", "host4", "host5"),
+ Array("host3", "host4")))
+ }
+
+ test("allocate locality preferred containers with enough resource and partially matched " +
+ "containers") {
+ // 1. Parts of current containers' locations can satisfy the new requirements
+ // 2. Current requested container number can fully satisfy the pending tasks.
+
+ val handler = createAllocator(3)
+ handler.updateResourceRequests()
+ handler.handleAllocatedContainers(Array(
+ createContainer("host1"),
+ createContainer("host1"),
+ createContainer("host2")
+ ))
+
+ val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+ 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10),
+ handler.allocatedHostToContainersMap, Seq.empty)
+
+ assert(localities.map(_.nodes) ===
+ Array(null, Array("host2", "host3"), Array("host2", "host3")))
+ }
+
+ test("allocate locality preferred containers with limited resource and partially matched " +
+ "containers") {
+ // 1. Parts of current containers' locations can satisfy the new requirements
+ // 2. Current requested container number cannot fully satisfy the pending tasks.
+
+ val handler = createAllocator(3)
+ handler.updateResourceRequests()
+ handler.handleAllocatedContainers(Array(
+ createContainer("host1"),
+ createContainer("host1"),
+ createContainer("host2")
+ ))
+
+ val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+ 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10),
+ handler.allocatedHostToContainersMap, Seq.empty)
+
+ assert(localities.map(_.nodes) === Array(Array("host2", "host3")))
+ }
+
+ test("allocate locality preferred containers with fully matched containers") {
+ // Current containers' locations can fully satisfy the new requirements
+
+ val handler = createAllocator(5)
+ handler.updateResourceRequests()
+ handler.handleAllocatedContainers(Array(
+ createContainer("host1"),
+ createContainer("host1"),
+ createContainer("host2"),
+ createContainer("host2"),
+ createContainer("host3")
+ ))
+
+ val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+ 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10),
+ handler.allocatedHostToContainersMap, Seq.empty)
+
+ assert(localities.map(_.nodes) === Array(null, null, null))
+ }
+
+ test("allocate containers with no locality preference") {
+ // Request new container without locality preference
+
+ val handler = createAllocator(2)
+ handler.updateResourceRequests()
+ handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2")))
+
+ val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+ 1, 0, Map.empty, handler.allocatedHostToContainersMap, Seq.empty)
+
+ assert(localities.map(_.nodes) === Array(null))
+ }
+
+ test("allocate locality preferred containers by considering the localities of pending requests") {
+ val handler = createAllocator(3)
+ handler.updateResourceRequests()
+ handler.handleAllocatedContainers(Array(
+ createContainer("host1"),
+ createContainer("host1"),
+ createContainer("host2")
+ ))
+
+ val pendingAllocationRequests = Seq(
+ createContainerRequest(Array("host2", "host3")),
+ createContainerRequest(Array("host1", "host4")))
+
+ val localities = handler.containerPlacementStrategy.localityOfRequestedContainers(
+ 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10),
+ handler.allocatedHostToContainersMap, pendingAllocationRequests)
+
+ assert(localities.map(_.nodes) === Array(Array("host3")))
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
new file mode 100644
index 0000000000..994dc75d34
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.util.{Arrays, List => JList}
+
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic
+import org.apache.hadoop.net.DNSToSwitchMapping
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.mockito.Mockito._
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.deploy.yarn.YarnAllocator._
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.scheduler.SplitInfo
+import org.apache.spark.util.ManualClock
+
+class MockResolver extends DNSToSwitchMapping {
+
+ override def resolve(names: JList[String]): JList[String] = {
+ if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2")
+ else Arrays.asList("/rack1")
+ }
+
+ override def reloadCachedMappings() {}
+
+ def reloadCachedMappings(names: JList[String]) {}
+}
+
+class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
+ val conf = new YarnConfiguration()
+ conf.setClass(
+ CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
+ classOf[MockResolver], classOf[DNSToSwitchMapping])
+
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.driver.host", "localhost")
+ sparkConf.set("spark.driver.port", "4040")
+ sparkConf.set(SPARK_JARS, Seq("notarealjar.jar"))
+ sparkConf.set("spark.yarn.launchContainers", "false")
+
+ val appAttemptId = ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0)
+
+ // Resource returned by YARN. YARN can give larger containers than requested, so give 6 cores
+ // instead of the 5 requested and 3 GB instead of the 2 requested.
+ val containerResource = Resource.newInstance(3072, 6)
+
+ var rmClient: AMRMClient[ContainerRequest] = _
+
+ var containerNum = 0
+
+ override def beforeEach() {
+ super.beforeEach()
+ rmClient = AMRMClient.createAMRMClient()
+ rmClient.init(conf)
+ rmClient.start()
+ }
+
+ override def afterEach() {
+ try {
+ rmClient.stop()
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) {
+ override def hashCode(): Int = 0
+ override def equals(other: Any): Boolean = false
+ }
+
+ def createAllocator(maxExecutors: Int = 5): YarnAllocator = {
+ val args = Array(
+ "--jar", "somejar.jar",
+ "--class", "SomeClass")
+ val sparkConfClone = sparkConf.clone()
+ sparkConfClone
+ .set("spark.executor.instances", maxExecutors.toString)
+ .set("spark.executor.cores", "5")
+ .set("spark.executor.memory", "2048")
+ new YarnAllocator(
+ "not used",
+ mock(classOf[RpcEndpointRef]),
+ conf,
+ sparkConfClone,
+ rmClient,
+ appAttemptId,
+ new SecurityManager(sparkConf),
+ Map())
+ }
+
+ def createContainer(host: String): Container = {
+ // When YARN 2.6+ is required, avoid deprecation by using version with long second arg
+ val containerId = ContainerId.newInstance(appAttemptId, containerNum)
+ containerNum += 1
+ val nodeId = NodeId.newInstance(host, 1000)
+ Container.newInstance(containerId, nodeId, "", containerResource, RM_REQUEST_PRIORITY, null)
+ }
+
+ test("single container allocated") {
+ // request a single container and receive it
+ val handler = createAllocator(1)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (1)
+
+ val container = createContainer("host1")
+ handler.handleAllocatedContainers(Array(container))
+
+ handler.getNumExecutorsRunning should be (1)
+ handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
+
+ val size = rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size
+ size should be (0)
+ }
+
+ test("container should not be created if requested number if met") {
+ // request a single container and receive it
+ val handler = createAllocator(1)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (1)
+
+ val container = createContainer("host1")
+ handler.handleAllocatedContainers(Array(container))
+
+ handler.getNumExecutorsRunning should be (1)
+ handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
+
+ val container2 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container2))
+ handler.getNumExecutorsRunning should be (1)
+ }
+
+ test("some containers allocated") {
+ // request a few containers and receive some of them
+ val handler = createAllocator(4)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (4)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host1")
+ val container3 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container1, container2, container3))
+
+ handler.getNumExecutorsRunning should be (3)
+ handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1")
+ handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host1")
+ handler.allocatedContainerToHostMap.get(container3.getId).get should be ("host2")
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId)
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container2.getId)
+ handler.allocatedHostToContainersMap.get("host2").get should contain (container3.getId)
+ }
+
+ test("receive more containers than requested") {
+ val handler = createAllocator(2)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (2)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host2")
+ val container3 = createContainer("host4")
+ handler.handleAllocatedContainers(Array(container1, container2, container3))
+
+ handler.getNumExecutorsRunning should be (2)
+ handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1")
+ handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host2")
+ handler.allocatedContainerToHostMap.contains(container3.getId) should be (false)
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId)
+ handler.allocatedHostToContainersMap.get("host2").get should contain (container2.getId)
+ handler.allocatedHostToContainersMap.contains("host4") should be (false)
+ }
+
+ test("decrease total requested executors") {
+ val handler = createAllocator(4)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (4)
+
+ handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty)
+ handler.updateResourceRequests()
+ handler.getPendingAllocate.size should be (3)
+
+ val container = createContainer("host1")
+ handler.handleAllocatedContainers(Array(container))
+
+ handler.getNumExecutorsRunning should be (1)
+ handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
+
+ handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty)
+ handler.updateResourceRequests()
+ handler.getPendingAllocate.size should be (1)
+ }
+
+ test("decrease total requested executors to less than currently running") {
+ val handler = createAllocator(4)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (4)
+
+ handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty)
+ handler.updateResourceRequests()
+ handler.getPendingAllocate.size should be (3)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container1, container2))
+
+ handler.getNumExecutorsRunning should be (2)
+
+ handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty)
+ handler.updateResourceRequests()
+ handler.getPendingAllocate.size should be (0)
+ handler.getNumExecutorsRunning should be (2)
+ }
+
+ test("kill executors") {
+ val handler = createAllocator(4)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (4)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container1, container2))
+
+ handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty)
+ handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) }
+
+ val statuses = Seq(container1, container2).map { c =>
+ ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0)
+ }
+ handler.updateResourceRequests()
+ handler.processCompletedContainers(statuses.toSeq)
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (1)
+ }
+
+ test("lost executor removed from backend") {
+ val handler = createAllocator(4)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (4)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container1, container2))
+
+ handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map())
+
+ val statuses = Seq(container1, container2).map { c =>
+ ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1)
+ }
+ handler.updateResourceRequests()
+ handler.processCompletedContainers(statuses.toSeq)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (2)
+ handler.getNumExecutorsFailed should be (2)
+ handler.getNumUnexpectedContainerRelease should be (2)
+ }
+
+ test("memory exceeded diagnostic regexes") {
+ val diagnostics =
+ "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +
+ "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " +
+ "5.8 GB of 4.2 GB virtual memory used. Killing container."
+ val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN)
+ val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN)
+ assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used."))
+ assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used."))
+ }
+
+ test("window based failure executor counting") {
+ sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s")
+ val handler = createAllocator(4)
+ val clock = new ManualClock(0L)
+ handler.setClock(clock)
+
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (4)
+
+ val containers = Seq(
+ createContainer("host1"),
+ createContainer("host2"),
+ createContainer("host3"),
+ createContainer("host4")
+ )
+ handler.handleAllocatedContainers(containers)
+
+ val failedStatuses = containers.map { c =>
+ ContainerStatus.newInstance(c.getId, ContainerState.COMPLETE, "Failed", -1)
+ }
+
+ handler.getNumExecutorsFailed should be (0)
+
+ clock.advance(100 * 1000L)
+ handler.processCompletedContainers(failedStatuses.slice(0, 1))
+ handler.getNumExecutorsFailed should be (1)
+
+ clock.advance(101 * 1000L)
+ handler.getNumExecutorsFailed should be (0)
+
+ handler.processCompletedContainers(failedStatuses.slice(1, 3))
+ handler.getNumExecutorsFailed should be (2)
+
+ clock.advance(50 * 1000L)
+ handler.processCompletedContainers(failedStatuses.slice(3, 4))
+ handler.getNumExecutorsFailed should be (3)
+
+ clock.advance(51 * 1000L)
+ handler.getNumExecutorsFailed should be (1)
+
+ clock.advance(50 * 1000L)
+ handler.getNumExecutorsFailed should be (0)
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
new file mode 100644
index 0000000000..99fb58a289
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.net.URL
+import java.nio.charset.StandardCharsets
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import com.google.common.io.{ByteStreams, Files}
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.Matchers
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.launcher._
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
+ SparkListenerExecutorAdded}
+import org.apache.spark.scheduler.cluster.ExecutorInfo
+import org.apache.spark.tags.ExtendedYarnTest
+import org.apache.spark.util.Utils
+
+/**
+ * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN
+ * applications, and require the Spark assembly to be built before they can be successfully
+ * run.
+ */
+@ExtendedYarnTest
+class YarnClusterSuite extends BaseYarnClusterSuite {
+
+ override def newYarnConfig(): YarnConfiguration = new YarnConfiguration()
+
+ private val TEST_PYFILE = """
+ |import mod1, mod2
+ |import sys
+ |from operator import add
+ |
+ |from pyspark import SparkConf , SparkContext
+ |if __name__ == "__main__":
+ | if len(sys.argv) != 2:
+ | print >> sys.stderr, "Usage: test.py [result file]"
+ | exit(-1)
+ | sc = SparkContext(conf=SparkConf())
+ | status = open(sys.argv[1],'w')
+ | result = "failure"
+ | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func())
+ | cnt = rdd.count()
+ | if cnt == 10:
+ | result = "success"
+ | status.write(result)
+ | status.close()
+ | sc.stop()
+ """.stripMargin
+
+ private val TEST_PYMODULE = """
+ |def func():
+ | return 42
+ """.stripMargin
+
+ test("run Spark in yarn-client mode") {
+ testBasicYarnApp(true)
+ }
+
+ test("run Spark in yarn-cluster mode") {
+ testBasicYarnApp(false)
+ }
+
+ test("run Spark in yarn-client mode with different configurations") {
+ testBasicYarnApp(true,
+ Map(
+ "spark.driver.memory" -> "512m",
+ "spark.executor.cores" -> "1",
+ "spark.executor.memory" -> "512m",
+ "spark.executor.instances" -> "2"
+ ))
+ }
+
+ test("run Spark in yarn-cluster mode with different configurations") {
+ testBasicYarnApp(false,
+ Map(
+ "spark.driver.memory" -> "512m",
+ "spark.driver.cores" -> "1",
+ "spark.executor.cores" -> "1",
+ "spark.executor.memory" -> "512m",
+ "spark.executor.instances" -> "2"
+ ))
+ }
+
+ test("run Spark in yarn-cluster mode with using SparkHadoopUtil.conf") {
+ testYarnAppUseSparkHadoopUtilConf()
+ }
+
+ test("run Spark in yarn-client mode with additional jar") {
+ testWithAddJar(true)
+ }
+
+ test("run Spark in yarn-cluster mode with additional jar") {
+ testWithAddJar(false)
+ }
+
+ test("run Spark in yarn-cluster mode unsuccessfully") {
+ // Don't provide arguments so the driver will fail.
+ val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass))
+ finalState should be (SparkAppHandle.State.FAILED)
+ }
+
+ test("run Spark in yarn-cluster mode failure after sc initialized") {
+ val finalState = runSpark(false, mainClassName(YarnClusterDriverWithFailure.getClass))
+ finalState should be (SparkAppHandle.State.FAILED)
+ }
+
+ test("run Python application in yarn-client mode") {
+ testPySpark(true)
+ }
+
+ test("run Python application in yarn-cluster mode") {
+ testPySpark(false)
+ }
+
+ test("run Python application in yarn-cluster mode using " +
+ " spark.yarn.appMasterEnv to override local envvar") {
+ testPySpark(
+ clientMode = false,
+ extraConf = Map(
+ "spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON"
+ -> sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python"),
+ "spark.yarn.appMasterEnv.PYSPARK_PYTHON"
+ -> sys.env.getOrElse("PYSPARK_PYTHON", "python")),
+ extraEnv = Map(
+ "PYSPARK_DRIVER_PYTHON" -> "not python",
+ "PYSPARK_PYTHON" -> "not python"))
+ }
+
+ test("user class path first in client mode") {
+ testUseClassPathFirst(true)
+ }
+
+ test("user class path first in cluster mode") {
+ testUseClassPathFirst(false)
+ }
+
+ test("monitor app using launcher library") {
+ val env = new JHashMap[String, String]()
+ env.put("YARN_CONF_DIR", hadoopConfDir.getAbsolutePath())
+
+ val propsFile = createConfFile()
+ val handle = new SparkLauncher(env)
+ .setSparkHome(sys.props("spark.test.home"))
+ .setConf("spark.ui.enabled", "false")
+ .setPropertiesFile(propsFile)
+ .setMaster("yarn")
+ .setDeployMode("client")
+ .setAppResource(SparkLauncher.NO_RESOURCE)
+ .setMainClass(mainClassName(YarnLauncherTestApp.getClass))
+ .startApplication()
+
+ try {
+ eventually(timeout(30 seconds), interval(100 millis)) {
+ handle.getState() should be (SparkAppHandle.State.RUNNING)
+ }
+
+ handle.getAppId() should not be (null)
+ handle.getAppId() should startWith ("application_")
+ handle.stop()
+
+ eventually(timeout(30 seconds), interval(100 millis)) {
+ handle.getState() should be (SparkAppHandle.State.KILLED)
+ }
+ } finally {
+ handle.kill()
+ }
+ }
+
+ test("timeout to get SparkContext in cluster mode triggers failure") {
+ val timeout = 2000
+ val finalState = runSpark(false, mainClassName(SparkContextTimeoutApp.getClass),
+ appArgs = Seq((timeout * 4).toString),
+ extraConf = Map(AM_MAX_WAIT_TIME.key -> timeout.toString))
+ finalState should be (SparkAppHandle.State.FAILED)
+ }
+
+ private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
+ val result = File.createTempFile("result", null, tempDir)
+ val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
+ appArgs = Seq(result.getAbsolutePath()),
+ extraConf = conf)
+ checkResult(finalState, result)
+ }
+
+ private def testYarnAppUseSparkHadoopUtilConf(): Unit = {
+ val result = File.createTempFile("result", null, tempDir)
+ val finalState = runSpark(false,
+ mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass),
+ appArgs = Seq("key=value", result.getAbsolutePath()),
+ extraConf = Map("spark.hadoop.key" -> "value"))
+ checkResult(finalState, result)
+ }
+
+ private def testWithAddJar(clientMode: Boolean): Unit = {
+ val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
+ val driverResult = File.createTempFile("driver", null, tempDir)
+ val executorResult = File.createTempFile("executor", null, tempDir)
+ val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
+ appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()),
+ extraClassPath = Seq(originalJar.getPath()),
+ extraJars = Seq("local:" + originalJar.getPath()))
+ checkResult(finalState, driverResult, "ORIGINAL")
+ checkResult(finalState, executorResult, "ORIGINAL")
+ }
+
+ private def testPySpark(
+ clientMode: Boolean,
+ extraConf: Map[String, String] = Map(),
+ extraEnv: Map[String, String] = Map()): Unit = {
+ val primaryPyFile = new File(tempDir, "test.py")
+ Files.write(TEST_PYFILE, primaryPyFile, StandardCharsets.UTF_8)
+
+ // When running tests, let's not assume the user has built the assembly module, which also
+ // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the
+ // needed locations.
+ val sparkHome = sys.props("spark.test.home")
+ val pythonPath = Seq(
+ s"$sparkHome/python/lib/py4j-0.10.4-src.zip",
+ s"$sparkHome/python")
+ val extraEnvVars = Map(
+ "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),
+ "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv
+
+ val moduleDir =
+ if (clientMode) {
+ // In client-mode, .py files added with --py-files are not visible in the driver.
+ // This is something that the launcher library would have to handle.
+ tempDir
+ } else {
+ val subdir = new File(tempDir, "pyModules")
+ subdir.mkdir()
+ subdir
+ }
+ val pyModule = new File(moduleDir, "mod1.py")
+ Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8)
+
+ val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir)
+ val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",")
+ val result = File.createTempFile("result", null, tempDir)
+
+ val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(),
+ sparkArgs = Seq("--py-files" -> pyFiles),
+ appArgs = Seq(result.getAbsolutePath()),
+ extraEnv = extraEnvVars,
+ extraConf = extraConf)
+ checkResult(finalState, result)
+ }
+
+ private def testUseClassPathFirst(clientMode: Boolean): Unit = {
+ // Create a jar file that contains a different version of "test.resource".
+ val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
+ val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir)
+ val driverResult = File.createTempFile("driver", null, tempDir)
+ val executorResult = File.createTempFile("executor", null, tempDir)
+ val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
+ appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()),
+ extraClassPath = Seq(originalJar.getPath()),
+ extraJars = Seq("local:" + userJar.getPath()),
+ extraConf = Map(
+ "spark.driver.userClassPathFirst" -> "true",
+ "spark.executor.userClassPathFirst" -> "true"))
+ checkResult(finalState, driverResult, "OVERRIDDEN")
+ checkResult(finalState, executorResult, "OVERRIDDEN")
+ }
+
+}
+
+private[spark] class SaveExecutorInfo extends SparkListener {
+ val addedExecutorInfos = mutable.Map[String, ExecutorInfo]()
+ var driverLogs: Option[collection.Map[String, String]] = None
+
+ override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
+ addedExecutorInfos(executor.executorId) = executor.executorInfo
+ }
+
+ override def onApplicationStart(appStart: SparkListenerApplicationStart): Unit = {
+ driverLogs = appStart.driverLogs
+ }
+}
+
+private object YarnClusterDriverWithFailure extends Logging with Matchers {
+ def main(args: Array[String]): Unit = {
+ val sc = new SparkContext(new SparkConf()
+ .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
+ .setAppName("yarn test with failure"))
+
+ throw new Exception("exception after sc initialized")
+ }
+}
+
+private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matchers {
+ def main(args: Array[String]): Unit = {
+ if (args.length != 2) {
+ // scalastyle:off println
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value] [result file]
+ """.stripMargin)
+ // scalastyle:on println
+ System.exit(1)
+ }
+
+ val sc = new SparkContext(new SparkConf()
+ .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
+ .setAppName("yarn test using SparkHadoopUtil's conf"))
+
+ val kv = args(0).split("=")
+ val status = new File(args(1))
+ var result = "failure"
+ try {
+ SparkHadoopUtil.get.conf.get(kv(0)) should be (kv(1))
+ result = "success"
+ } finally {
+ Files.write(result, status, StandardCharsets.UTF_8)
+ sc.stop()
+ }
+ }
+}
+
+private object YarnClusterDriver extends Logging with Matchers {
+
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
+ // scalastyle:off println
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: YarnClusterDriver [result file]
+ """.stripMargin)
+ // scalastyle:on println
+ System.exit(1)
+ }
+
+ val sc = new SparkContext(new SparkConf()
+ .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
+ .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns"))
+ val conf = sc.getConf
+ val status = new File(args(0))
+ var result = "failure"
+ try {
+ val data = sc.parallelize(1 to 4, 4).collect().toSet
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ data should be (Set(1, 2, 3, 4))
+ result = "success"
+
+ // Verify that the config archive is correctly placed in the classpath of all containers.
+ val confFile = "/" + Client.SPARK_CONF_FILE
+ assert(getClass().getResource(confFile) != null)
+ val configFromExecutors = sc.parallelize(1 to 4, 4)
+ .map { _ => Option(getClass().getResource(confFile)).map(_.toString).orNull }
+ .collect()
+ assert(configFromExecutors.find(_ == null) === None)
+ } finally {
+ Files.write(result, status, StandardCharsets.UTF_8)
+ sc.stop()
+ }
+
+ // verify log urls are present
+ val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]
+ assert(listeners.size === 1)
+ val listener = listeners(0)
+ val executorInfos = listener.addedExecutorInfos.values
+ assert(executorInfos.nonEmpty)
+ executorInfos.foreach { info =>
+ assert(info.logUrlMap.nonEmpty)
+ }
+
+ // If we are running in yarn-cluster mode, verify that driver logs links and present and are
+ // in the expected format.
+ if (conf.get("spark.submit.deployMode") == "cluster") {
+ assert(listener.driverLogs.nonEmpty)
+ val driverLogs = listener.driverLogs.get
+ assert(driverLogs.size === 2)
+ assert(driverLogs.contains("stderr"))
+ assert(driverLogs.contains("stdout"))
+ val urlStr = driverLogs("stderr")
+ // Ensure that this is a valid URL, else this will throw an exception
+ new URL(urlStr)
+ val containerId = YarnSparkHadoopUtil.get.getContainerId
+ val user = Utils.getCurrentUserName()
+ assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096"))
+ }
+ }
+
+}
+
+private object YarnClasspathTest extends Logging {
+ def error(m: String, ex: Throwable = null): Unit = {
+ logError(m, ex)
+ // scalastyle:off println
+ System.out.println(m)
+ if (ex != null) {
+ ex.printStackTrace(System.out)
+ }
+ // scalastyle:on println
+ }
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 2) {
+ error(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: YarnClasspathTest [driver result file] [executor result file]
+ """.stripMargin)
+ // scalastyle:on println
+ }
+
+ readResource(args(0))
+ val sc = new SparkContext(new SparkConf())
+ try {
+ sc.parallelize(Seq(1)).foreach { x => readResource(args(1)) }
+ } finally {
+ sc.stop()
+ }
+ }
+
+ private def readResource(resultPath: String): Unit = {
+ var result = "failure"
+ try {
+ val ccl = Thread.currentThread().getContextClassLoader()
+ val resource = ccl.getResourceAsStream("test.resource")
+ val bytes = ByteStreams.toByteArray(resource)
+ result = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8)
+ } catch {
+ case t: Throwable =>
+ error(s"loading test.resource to $resultPath", t)
+ } finally {
+ Files.write(result, new File(resultPath), StandardCharsets.UTF_8)
+ }
+ }
+
+}
+
+private object YarnLauncherTestApp {
+
+ def main(args: Array[String]): Unit = {
+ // Do not stop the application; the test will stop it using the launcher lib. Just run a task
+ // that will prevent the process from exiting.
+ val sc = new SparkContext(new SparkConf())
+ sc.parallelize(Seq(1)).foreach { i =>
+ this.synchronized {
+ wait()
+ }
+ }
+ }
+
+}
+
+/**
+ * Used to test code in the AM that detects the SparkContext instance. Expects a single argument
+ * with the duration to sleep for, in ms.
+ */
+private object SparkContextTimeoutApp {
+
+ def main(args: Array[String]): Unit = {
+ val Array(sleepTime) = args
+ Thread.sleep(java.lang.Long.parseLong(sleepTime))
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
new file mode 100644
index 0000000000..950ebd9a2d
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
@@ -0,0 +1,112 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import com.google.common.io.Files
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.Matchers
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.ShuffleTestAccessor
+import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor}
+import org.apache.spark.tags.ExtendedYarnTest
+
+/**
+ * Integration test for the external shuffle service with a yarn mini-cluster
+ */
+@ExtendedYarnTest
+class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite {
+
+ override def newYarnConfig(): YarnConfiguration = {
+ val yarnConfig = new YarnConfiguration()
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle")
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"),
+ classOf[YarnShuffleService].getCanonicalName)
+ yarnConfig.set("spark.shuffle.service.port", "0")
+ yarnConfig
+ }
+
+ test("external shuffle service") {
+ val shuffleServicePort = YarnTestAccessor.getShuffleServicePort
+ val shuffleService = YarnTestAccessor.getShuffleServiceInstance
+
+ val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService)
+
+ logInfo("Shuffle service port = " + shuffleServicePort)
+ val result = File.createTempFile("result", null, tempDir)
+ val finalState = runSpark(
+ false,
+ mainClassName(YarnExternalShuffleDriver.getClass),
+ appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath),
+ extraConf = Map(
+ "spark.shuffle.service.enabled" -> "true",
+ "spark.shuffle.service.port" -> shuffleServicePort.toString
+ )
+ )
+ checkResult(finalState, result)
+ assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists())
+ }
+}
+
+private object YarnExternalShuffleDriver extends Logging with Matchers {
+
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 2) {
+ // scalastyle:off println
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: ExternalShuffleDriver [result file] [registered exec file]
+ """.stripMargin)
+ // scalastyle:on println
+ System.exit(1)
+ }
+
+ val sc = new SparkContext(new SparkConf()
+ .setAppName("External Shuffle Test"))
+ val conf = sc.getConf
+ val status = new File(args(0))
+ val registeredExecFile = new File(args(1))
+ logInfo("shuffle service executor file = " + registeredExecFile)
+ var result = "failure"
+ val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup")
+ try {
+ val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }.
+ collect().toSet
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet)
+ result = "success"
+ // only one process can open a leveldb file at a time, so we copy the files
+ FileUtils.copyDirectory(registeredExecFile, execStateCopy)
+ assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty)
+ } finally {
+ sc.stop()
+ FileUtils.deleteDirectory(execStateCopy)
+ Files.write(result, status, StandardCharsets.UTF_8)
+ }
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
new file mode 100644
index 0000000000..7fbbe12609
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, IOException}
+import java.nio.charset.StandardCharsets
+
+import com.google.common.io.{ByteStreams, Files}
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.yarn.api.ApplicationConstants
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.records.ApplicationAccessType
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.Matchers
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{ResetSystemProperties, Utils}
+
+class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
+ with ResetSystemProperties {
+
+ val hasBash =
+ try {
+ val exitCode = Runtime.getRuntime().exec(Array("bash", "--version")).waitFor()
+ exitCode == 0
+ } catch {
+ case e: IOException =>
+ false
+ }
+
+ if (!hasBash) {
+ logWarning("Cannot execute bash, skipping bash tests.")
+ }
+
+ def bashTest(name: String)(fn: => Unit): Unit =
+ if (hasBash) test(name)(fn) else ignore(name)(fn)
+
+ bashTest("shell script escaping") {
+ val scriptFile = File.createTempFile("script.", ".sh", Utils.createTempDir())
+ val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6")
+ try {
+ val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ")
+ Files.write(("bash -c \"echo " + argLine + "\"").getBytes(StandardCharsets.UTF_8), scriptFile)
+ scriptFile.setExecutable(true)
+
+ val proc = Runtime.getRuntime().exec(Array(scriptFile.getAbsolutePath()))
+ val out = new String(ByteStreams.toByteArray(proc.getInputStream())).trim()
+ val err = new String(ByteStreams.toByteArray(proc.getErrorStream()))
+ val exitCode = proc.waitFor()
+ exitCode should be (0)
+ out should be (args.mkString(" "))
+ } finally {
+ scriptFile.delete()
+ }
+ }
+
+ test("Yarn configuration override") {
+ val key = "yarn.nodemanager.hostname"
+ val default = new YarnConfiguration()
+
+ val sparkConf = new SparkConf()
+ .set("spark.hadoop." + key, "someHostName")
+ val yarnConf = new YarnSparkHadoopUtil().newConfiguration(sparkConf)
+
+ yarnConf.getClass() should be (classOf[YarnConfiguration])
+ yarnConf.get(key) should not be default.get(key)
+ }
+
+
+ test("test getApplicationAclsForYarn acls on") {
+
+ // spark acls on, just pick up default user
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.acls.enable", "true")
+
+ val securityMgr = new SecurityManager(sparkConf)
+ val acls = YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)
+
+ val viewAcls = acls.get(ApplicationAccessType.VIEW_APP)
+ val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP)
+
+ viewAcls match {
+ case Some(vacls) =>
+ val aclSet = vacls.split(',').map(_.trim).toSet
+ assert(aclSet.contains(System.getProperty("user.name", "invalid")))
+ case None =>
+ fail()
+ }
+ modifyAcls match {
+ case Some(macls) =>
+ val aclSet = macls.split(',').map(_.trim).toSet
+ assert(aclSet.contains(System.getProperty("user.name", "invalid")))
+ case None =>
+ fail()
+ }
+ }
+
+ test("test getApplicationAclsForYarn acls on and specify users") {
+
+ // default spark acls are on and specify acls
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.acls.enable", "true")
+ sparkConf.set("spark.ui.view.acls", "user1,user2")
+ sparkConf.set("spark.modify.acls", "user3,user4")
+
+ val securityMgr = new SecurityManager(sparkConf)
+ val acls = YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)
+
+ val viewAcls = acls.get(ApplicationAccessType.VIEW_APP)
+ val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP)
+
+ viewAcls match {
+ case Some(vacls) =>
+ val aclSet = vacls.split(',').map(_.trim).toSet
+ assert(aclSet.contains("user1"))
+ assert(aclSet.contains("user2"))
+ assert(aclSet.contains(System.getProperty("user.name", "invalid")))
+ case None =>
+ fail()
+ }
+ modifyAcls match {
+ case Some(macls) =>
+ val aclSet = macls.split(',').map(_.trim).toSet
+ assert(aclSet.contains("user3"))
+ assert(aclSet.contains("user4"))
+ assert(aclSet.contains(System.getProperty("user.name", "invalid")))
+ case None =>
+ fail()
+ }
+
+ }
+
+ test("test expandEnvironment result") {
+ val target = Environment.PWD
+ if (classOf[Environment].getMethods().exists(_.getName == "$$")) {
+ YarnSparkHadoopUtil.expandEnvironment(target) should be ("{{" + target + "}}")
+ } else if (Utils.isWindows) {
+ YarnSparkHadoopUtil.expandEnvironment(target) should be ("%" + target + "%")
+ } else {
+ YarnSparkHadoopUtil.expandEnvironment(target) should be ("$" + target)
+ }
+
+ }
+
+ test("test getClassPathSeparator result") {
+ if (classOf[ApplicationConstants].getFields().exists(_.getName == "CLASS_PATH_SEPARATOR")) {
+ YarnSparkHadoopUtil.getClassPathSeparator() should be ("<CPS>")
+ } else if (Utils.isWindows) {
+ YarnSparkHadoopUtil.getClassPathSeparator() should be (";")
+ } else {
+ YarnSparkHadoopUtil.getClassPathSeparator() should be (":")
+ }
+ }
+
+ test("check different hadoop utils based on env variable") {
+ try {
+ System.setProperty("SPARK_YARN_MODE", "true")
+ assert(SparkHadoopUtil.get.getClass === classOf[YarnSparkHadoopUtil])
+ System.setProperty("SPARK_YARN_MODE", "false")
+ assert(SparkHadoopUtil.get.getClass === classOf[SparkHadoopUtil])
+ } finally {
+ System.clearProperty("SPARK_YARN_MODE")
+ }
+ }
+
+
+
+ // This test needs to live here because it depends on isYarnMode returning true, which can only
+ // happen in the YARN module.
+ test("security manager token generation") {
+ try {
+ System.setProperty("SPARK_YARN_MODE", "true")
+ val initial = SparkHadoopUtil.get
+ .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY)
+ assert(initial === null || initial.length === 0)
+
+ val conf = new SparkConf()
+ .set(SecurityManager.SPARK_AUTH_CONF, "true")
+ .set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused")
+ val sm = new SecurityManager(conf)
+
+ val generated = SparkHadoopUtil.get
+ .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY)
+ assert(generated != null)
+ val genString = new Text(generated).toString()
+ assert(genString != "unused")
+ assert(sm.getSecretKey() === genString)
+ } finally {
+ // removeSecretKey() was only added in Hadoop 2.6, so instead we just set the secret
+ // to an empty string.
+ SparkHadoopUtil.get.addSecretKeyToUserCredentials(SecurityManager.SECRET_LOOKUP_KEY, "")
+ System.clearProperty("SPARK_YARN_MODE")
+ }
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala
new file mode 100644
index 0000000000..db4619e80c
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn.security
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.security.token.Token
+import org.scalatest.{BeforeAndAfter, Matchers}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.yarn.config._
+
+class ConfigurableCredentialManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
+ private var credentialManager: ConfigurableCredentialManager = null
+ private var sparkConf: SparkConf = null
+ private var hadoopConf: Configuration = null
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ sparkConf = new SparkConf()
+ hadoopConf = new Configuration()
+ System.setProperty("SPARK_YARN_MODE", "true")
+ }
+
+ override def afterAll(): Unit = {
+ System.clearProperty("SPARK_YARN_MODE")
+
+ super.afterAll()
+ }
+
+ test("Correctly load default credential providers") {
+ credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)
+
+ credentialManager.getServiceCredentialProvider("hdfs") should not be (None)
+ credentialManager.getServiceCredentialProvider("hbase") should not be (None)
+ credentialManager.getServiceCredentialProvider("hive") should not be (None)
+ }
+
+ test("disable hive credential provider") {
+ sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false")
+ credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)
+
+ credentialManager.getServiceCredentialProvider("hdfs") should not be (None)
+ credentialManager.getServiceCredentialProvider("hbase") should not be (None)
+ credentialManager.getServiceCredentialProvider("hive") should be (None)
+ }
+
+ test("using deprecated configurations") {
+ sparkConf.set("spark.yarn.security.tokens.hdfs.enabled", "false")
+ sparkConf.set("spark.yarn.security.tokens.hive.enabled", "false")
+ credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)
+
+ credentialManager.getServiceCredentialProvider("hdfs") should be (None)
+ credentialManager.getServiceCredentialProvider("hive") should be (None)
+ credentialManager.getServiceCredentialProvider("test") should not be (None)
+ credentialManager.getServiceCredentialProvider("hbase") should not be (None)
+ }
+
+ test("verify obtaining credentials from provider") {
+ credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)
+ val creds = new Credentials()
+
+ // Tokens can only be obtained from TestTokenProvider, for hdfs, hbase and hive tokens cannot
+ // be obtained.
+ credentialManager.obtainCredentials(hadoopConf, creds)
+ val tokens = creds.getAllTokens
+ tokens.size() should be (1)
+ tokens.iterator().next().getService should be (new Text("test"))
+ }
+
+ test("verify getting credential renewal info") {
+ credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)
+ val creds = new Credentials()
+
+ val testCredentialProvider = credentialManager.getServiceCredentialProvider("test").get
+ .asInstanceOf[TestCredentialProvider]
+ // Only TestTokenProvider can get the time of next token renewal
+ val nextRenewal = credentialManager.obtainCredentials(hadoopConf, creds)
+ nextRenewal should be (testCredentialProvider.timeOfNextTokenRenewal)
+ }
+
+ test("obtain tokens For HiveMetastore") {
+ val hadoopConf = new Configuration()
+ hadoopConf.set("hive.metastore.kerberos.principal", "bob")
+ // thrift picks up on port 0 and bails out, without trying to talk to endpoint
+ hadoopConf.set("hive.metastore.uris", "http://localhost:0")
+
+ val hiveCredentialProvider = new HiveCredentialProvider()
+ val credentials = new Credentials()
+ hiveCredentialProvider.obtainCredentials(hadoopConf, sparkConf, credentials)
+
+ credentials.getAllTokens.size() should be (0)
+ }
+
+ test("Obtain tokens For HBase") {
+ val hadoopConf = new Configuration()
+ hadoopConf.set("hbase.security.authentication", "kerberos")
+
+ val hbaseTokenProvider = new HBaseCredentialProvider()
+ val creds = new Credentials()
+ hbaseTokenProvider.obtainCredentials(hadoopConf, sparkConf, creds)
+
+ creds.getAllTokens.size should be (0)
+ }
+}
+
+class TestCredentialProvider extends ServiceCredentialProvider {
+ val tokenRenewalInterval = 86400 * 1000L
+ var timeOfNextTokenRenewal = 0L
+
+ override def serviceName: String = "test"
+
+ override def credentialsRequired(conf: Configuration): Boolean = true
+
+ override def obtainCredentials(
+ hadoopConf: Configuration,
+ sparkConf: SparkConf,
+ creds: Credentials): Option[Long] = {
+ if (creds == null) {
+ // Guard out other unit test failures.
+ return None
+ }
+
+ val emptyToken = new Token()
+ emptyToken.setService(new Text("test"))
+ creds.addToken(emptyToken.getService, emptyToken)
+
+ val currTime = System.currentTimeMillis()
+ timeOfNextTokenRenewal = (currTime - currTime % tokenRenewalInterval) + tokenRenewalInterval
+
+ Some(timeOfNextTokenRenewal)
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProviderSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProviderSuite.scala
new file mode 100644
index 0000000000..7b2da3f26e
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProviderSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn.security
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.scalatest.{Matchers, PrivateMethodTester}
+
+import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+
+class HDFSCredentialProviderSuite
+ extends SparkFunSuite
+ with PrivateMethodTester
+ with Matchers {
+ private val _getTokenRenewer = PrivateMethod[String]('getTokenRenewer)
+
+ private def getTokenRenewer(
+ hdfsCredentialProvider: HDFSCredentialProvider, conf: Configuration): String = {
+ hdfsCredentialProvider invokePrivate _getTokenRenewer(conf)
+ }
+
+ private var hdfsCredentialProvider: HDFSCredentialProvider = null
+
+ override def beforeAll() {
+ super.beforeAll()
+
+ if (hdfsCredentialProvider == null) {
+ hdfsCredentialProvider = new HDFSCredentialProvider()
+ }
+ }
+
+ override def afterAll() {
+ if (hdfsCredentialProvider != null) {
+ hdfsCredentialProvider = null
+ }
+
+ super.afterAll()
+ }
+
+ test("check token renewer") {
+ val hadoopConf = new Configuration()
+ hadoopConf.set("yarn.resourcemanager.address", "myrm:8033")
+ hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM")
+ val renewer = getTokenRenewer(hdfsCredentialProvider, hadoopConf)
+ renewer should be ("yarn/myrm:8032@SPARKTEST.COM")
+ }
+
+ test("check token renewer default") {
+ val hadoopConf = new Configuration()
+ val caught =
+ intercept[SparkException] {
+ getTokenRenewer(hdfsCredentialProvider, hadoopConf)
+ }
+ assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer")
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala
new file mode 100644
index 0000000000..da9e8e21a2
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher
+
+import java.util.{List => JList, Map => JMap}
+
+/**
+ * Exposes AbstractCommandBuilder to the YARN tests, so that they can build classpaths the same
+ * way other cluster managers do.
+ */
+private[spark] class TestClasspathBuilder extends AbstractCommandBuilder {
+
+ childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sys.props("spark.test.home"))
+
+ override def buildClassPath(extraCp: String): JList[String] = super.buildClassPath(extraCp)
+
+ /** Not used by the YARN tests. */
+ override def buildCommand(env: JMap[String, String]): JList[String] =
+ throw new UnsupportedOperationException()
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
new file mode 100644
index 0000000000..1fed2562fc
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.shuffle
+
+import java.io.File
+import java.util.concurrent.ConcurrentMap
+
+import org.apache.hadoop.yarn.api.records.ApplicationId
+import org.fusesource.leveldbjni.JniDBFactory
+import org.iq80.leveldb.{DB, Options}
+
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+
+/**
+ * just a cheat to get package-visible members in tests
+ */
+object ShuffleTestAccessor {
+
+ def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = {
+ handler.blockManager
+ }
+
+ def getExecutorInfo(
+ appId: ApplicationId,
+ execId: String,
+ resolver: ExternalShuffleBlockResolver
+ ): Option[ExecutorShuffleInfo] = {
+ val id = new AppExecId(appId.toString, execId)
+ Option(resolver.executors.get(id))
+ }
+
+ def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = {
+ resolver.registeredExecutorFile
+ }
+
+ def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = {
+ resolver.db
+ }
+
+ def reloadRegisteredExecutors(
+ file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = {
+ val options: Options = new Options
+ options.createIfMissing(true)
+ val factory = new JniDBFactory
+ val db = factory.open(file, options)
+ val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db)
+ db.close()
+ result
+ }
+
+ def reloadRegisteredExecutors(
+ db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = {
+ ExternalShuffleBlockResolver.reloadRegisteredExecutors(db)
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
new file mode 100644
index 0000000000..a58784f596
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.yarn
+
+import java.io.{DataOutputStream, File, FileOutputStream, IOException}
+import java.nio.ByteBuffer
+import java.nio.file.Files
+import java.nio.file.attribute.PosixFilePermission._
+import java.util.EnumSet
+
+import scala.annotation.tailrec
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.service.ServiceStateException
+import org.apache.hadoop.yarn.api.records.ApplicationId
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext}
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.SecurityManager
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.network.shuffle.ShuffleTestAccessor
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+import org.apache.spark.util.Utils
+
+class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
+ private[yarn] var yarnConfig: YarnConfiguration = null
+ private[yarn] val SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ yarnConfig = new YarnConfiguration()
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle")
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"),
+ classOf[YarnShuffleService].getCanonicalName)
+ yarnConfig.setInt("spark.shuffle.service.port", 0)
+ yarnConfig.setBoolean(YarnShuffleService.STOP_ON_FAILURE_KEY, true)
+ val localDir = Utils.createTempDir()
+ yarnConfig.set(YarnConfiguration.NM_LOCAL_DIRS, localDir.getAbsolutePath)
+ }
+
+ var s1: YarnShuffleService = null
+ var s2: YarnShuffleService = null
+ var s3: YarnShuffleService = null
+
+ override def afterEach(): Unit = {
+ try {
+ if (s1 != null) {
+ s1.stop()
+ s1 = null
+ }
+ if (s2 != null) {
+ s2.stop()
+ s2 = null
+ }
+ if (s3 != null) {
+ s3.stop()
+ s3 = null
+ }
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ test("executor state kept across NM restart") {
+ s1 = new YarnShuffleService
+ // set auth to true to test the secrets recovery
+ yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, true)
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data = makeAppInfo("user", app1Id)
+ s1.initializeApplication(app1Data)
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data = makeAppInfo("user", app2Id)
+ s1.initializeApplication(app2Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ execStateFile should not be (null)
+ val secretsFile = s1.secretsFile
+ secretsFile should not be (null)
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, SORT_MANAGER)
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, SORT_MANAGER)
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+ blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should
+ be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should
+ be (Some(shuffleInfo2))
+
+ if (!execStateFile.exists()) {
+ @tailrec def findExistingParent(file: File): File = {
+ if (file == null) file
+ else if (file.exists()) file
+ else findExistingParent(file.getParentFile())
+ }
+ val existingParent = findExistingParent(execStateFile)
+ assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent")
+ }
+ assert(execStateFile.exists(), s"$execStateFile did not exist")
+
+ // now we pretend the shuffle service goes down, and comes back up
+ s1.stop()
+ s2 = new YarnShuffleService
+ s2.init(yarnConfig)
+ s2.secretsFile should be (secretsFile)
+ s2.registeredExecutorFile should be (execStateFile)
+
+ val handler2 = s2.blockHandler
+ val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2)
+
+ // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped
+ // during the restart
+ s2.initializeApplication(app1Data)
+ s2.stopApplication(new ApplicationTerminationContext(app2Id))
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None)
+
+ // Act like the NM restarts one more time
+ s2.stop()
+ s3 = new YarnShuffleService
+ s3.init(yarnConfig)
+ s3.registeredExecutorFile should be (execStateFile)
+ s3.secretsFile should be (secretsFile)
+
+ val handler3 = s3.blockHandler
+ val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3)
+
+ // app1 is still running
+ s3.initializeApplication(app1Data)
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None)
+ s3.stop()
+ }
+
+ test("removed applications should not be in registered executor file") {
+ s1 = new YarnShuffleService
+ yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, false)
+ s1.init(yarnConfig)
+ val secretsFile = s1.secretsFile
+ secretsFile should be (null)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data = makeAppInfo("user", app1Id)
+ s1.initializeApplication(app1Data)
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data = makeAppInfo("user", app2Id)
+ s1.initializeApplication(app2Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ execStateFile should not be (null)
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, SORT_MANAGER)
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, SORT_MANAGER)
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+ blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+
+ val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver)
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty
+
+ s1.stopApplication(new ApplicationTerminationContext(app1Id))
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty
+ s1.stopApplication(new ApplicationTerminationContext(app2Id))
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty
+ }
+
+ test("shuffle service should be robust to corrupt registered executor file") {
+ s1 = new YarnShuffleService
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data = makeAppInfo("user", app1Id)
+ s1.initializeApplication(app1Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, SORT_MANAGER)
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+
+ // now we pretend the shuffle service goes down, and comes back up. But we'll also
+ // make a corrupt registeredExecutor File
+ s1.stop()
+
+ execStateFile.listFiles().foreach{_.delete()}
+
+ val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT"))
+ out.writeInt(42)
+ out.close()
+
+ s2 = new YarnShuffleService
+ s2.init(yarnConfig)
+ s2.registeredExecutorFile should be (execStateFile)
+
+ val handler2 = s2.blockHandler
+ val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2)
+
+ // we re-initialize app1, but since the file was corrupt there is nothing we can do about it ...
+ s2.initializeApplication(app1Data)
+ // however, when we initialize a totally new app2, everything is still happy
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data = makeAppInfo("user", app2Id)
+ s2.initializeApplication(app2Data)
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, SORT_MANAGER)
+ resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2))
+ s2.stop()
+
+ // another stop & restart should be fine though (eg., we recover from previous corruption)
+ s3 = new YarnShuffleService
+ s3.init(yarnConfig)
+ s3.registeredExecutorFile should be (execStateFile)
+ val handler3 = s3.blockHandler
+ val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3)
+
+ s3.initializeApplication(app2Data)
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2))
+ s3.stop()
+ }
+
+ test("get correct recovery path") {
+ // Test recovery path is set outside the shuffle service, this is to simulate NM recovery
+ // enabled scenario, where recovery path will be set by yarn.
+ s1 = new YarnShuffleService
+ val recoveryPath = new Path(Utils.createTempDir().toURI)
+ s1.setRecoveryPath(recoveryPath)
+
+ s1.init(yarnConfig)
+ s1._recoveryPath should be (recoveryPath)
+ s1.stop()
+
+ // Test recovery path is set inside the shuffle service, this will be happened when NM
+ // recovery is not enabled or there's no NM recovery (Hadoop 2.5-).
+ s2 = new YarnShuffleService
+ s2.init(yarnConfig)
+ s2._recoveryPath should be
+ (new Path(yarnConfig.getTrimmedStrings("yarn.nodemanager.local-dirs")(0)))
+ s2.stop()
+ }
+
+ test("moving recovery file from NM local dir to recovery path") {
+ // This is to test when Hadoop is upgrade to 2.5+ and NM recovery is enabled, we should move
+ // old recovery file to the new path to keep compatibility
+
+ // Simulate s1 is running on old version of Hadoop in which recovery file is in the NM local
+ // dir.
+ s1 = new YarnShuffleService
+ // set auth to true to test the secrets recovery
+ yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, true)
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data = makeAppInfo("user", app1Id)
+ s1.initializeApplication(app1Data)
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data = makeAppInfo("user", app2Id)
+ s1.initializeApplication(app2Data)
+
+ assert(s1.secretManager.getSecretKey(app1Id.toString()) != null)
+ assert(s1.secretManager.getSecretKey(app2Id.toString()) != null)
+
+ val execStateFile = s1.registeredExecutorFile
+ execStateFile should not be (null)
+ val secretsFile = s1.secretsFile
+ secretsFile should not be (null)
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, SORT_MANAGER)
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, SORT_MANAGER)
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+ blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should
+ be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should
+ be (Some(shuffleInfo2))
+
+ assert(execStateFile.exists(), s"$execStateFile did not exist")
+
+ s1.stop()
+
+ // Simulate s2 is running on Hadoop 2.5+ with NM recovery is enabled.
+ assert(execStateFile.exists())
+ val recoveryPath = new Path(Utils.createTempDir().toURI)
+ s2 = new YarnShuffleService
+ s2.setRecoveryPath(recoveryPath)
+ s2.init(yarnConfig)
+
+ // Ensure that s2 has loaded known apps from the secrets db.
+ assert(s2.secretManager.getSecretKey(app1Id.toString()) != null)
+ assert(s2.secretManager.getSecretKey(app2Id.toString()) != null)
+
+ val execStateFile2 = s2.registeredExecutorFile
+ val secretsFile2 = s2.secretsFile
+
+ recoveryPath.toString should be (new Path(execStateFile2.getParentFile.toURI).toString)
+ recoveryPath.toString should be (new Path(secretsFile2.getParentFile.toURI).toString)
+ eventually(timeout(10 seconds), interval(5 millis)) {
+ assert(!execStateFile.exists())
+ }
+ eventually(timeout(10 seconds), interval(5 millis)) {
+ assert(!secretsFile.exists())
+ }
+
+ val handler2 = s2.blockHandler
+ val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2)
+
+ // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped
+ // during the restart
+ // Since recovery file is got from old path, so the previous state should be stored.
+ s2.initializeApplication(app1Data)
+ s2.stopApplication(new ApplicationTerminationContext(app2Id))
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None)
+
+ s2.stop()
+ }
+
+ test("service throws error if cannot start") {
+ // Set up a read-only local dir.
+ val roDir = Utils.createTempDir()
+ Files.setPosixFilePermissions(roDir.toPath(), EnumSet.of(OWNER_READ, OWNER_EXECUTE))
+ yarnConfig.set(YarnConfiguration.NM_LOCAL_DIRS, roDir.getAbsolutePath())
+
+ // Try to start the shuffle service, it should fail.
+ val service = new YarnShuffleService()
+
+ try {
+ val error = intercept[ServiceStateException] {
+ service.init(yarnConfig)
+ }
+ assert(error.getCause().isInstanceOf[IOException])
+ } finally {
+ service.stop()
+ Files.setPosixFilePermissions(roDir.toPath(),
+ EnumSet.of(OWNER_READ, OWNER_WRITE, OWNER_EXECUTE))
+ }
+ }
+
+ private def makeAppInfo(user: String, appId: ApplicationId): ApplicationInitializationContext = {
+ val secret = ByteBuffer.wrap(new Array[Byte](0))
+ new ApplicationInitializationContext(user, appId, secret)
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala
new file mode 100644
index 0000000000..db322cd18e
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.yarn
+
+import java.io.File
+
+/**
+ * just a cheat to get package-visible members in tests
+ */
+object YarnTestAccessor {
+ def getShuffleServicePort: Int = {
+ YarnShuffleService.boundPort
+ }
+
+ def getShuffleServiceInstance: YarnShuffleService = {
+ YarnShuffleService.instance
+ }
+
+ def getRegisteredExecutorFile(service: YarnShuffleService): File = {
+ service.registeredExecutorFile
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala
new file mode 100644
index 0000000000..6ea7984c64
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+
+/**
+ * Test the integration with [[SchedulerExtensionServices]]
+ */
+class ExtensionServiceIntegrationSuite extends SparkFunSuite
+ with LocalSparkContext with BeforeAndAfter
+ with Logging {
+
+ val applicationId = new StubApplicationId(0, 1111L)
+ val attemptId = new StubApplicationAttemptId(applicationId, 1)
+
+ /*
+ * Setup phase creates the spark context
+ */
+ before {
+ val sparkConf = new SparkConf()
+ sparkConf.set(SCHEDULER_SERVICES, Seq(classOf[SimpleExtensionService].getName()))
+ sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite")
+ sc = new SparkContext(sparkConf)
+ }
+
+ test("Instantiate") {
+ val services = new SchedulerExtensionServices()
+ assertResult(Nil, "non-nil service list") {
+ services.getServices
+ }
+ services.start(SchedulerExtensionServiceBinding(sc, applicationId))
+ services.stop()
+ }
+
+ test("Contains SimpleExtensionService Service") {
+ val services = new SchedulerExtensionServices()
+ try {
+ services.start(SchedulerExtensionServiceBinding(sc, applicationId))
+ val serviceList = services.getServices
+ assert(serviceList.nonEmpty, "empty service list")
+ val (service :: Nil) = serviceList
+ val simpleService = service.asInstanceOf[SimpleExtensionService]
+ assert(simpleService.started.get, "service not started")
+ services.stop()
+ assert(!simpleService.started.get, "service not stopped")
+ } finally {
+ services.stop()
+ }
+ }
+}
+
+
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala
new file mode 100644
index 0000000000..9b8c98cda8
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+private[spark] class SimpleExtensionService extends SchedulerExtensionService {
+
+ /** started flag; set in the `start()` call, stopped in `stop()`. */
+ val started = new AtomicBoolean(false)
+
+ override def start(binding: SchedulerExtensionServiceBinding): Unit = {
+ started.set(true)
+ }
+
+ override def stop(): Unit = {
+ started.set(false)
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala
new file mode 100644
index 0000000000..4b57b9509a
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
+
+/**
+ * A stub application ID; can be set in constructor and/or updated later.
+ * @param applicationId application ID
+ * @param attempt an attempt counter
+ */
+class StubApplicationAttemptId(var applicationId: ApplicationId, var attempt: Int)
+ extends ApplicationAttemptId {
+
+ override def setApplicationId(appID: ApplicationId): Unit = {
+ applicationId = appID
+ }
+
+ override def getAttemptId: Int = {
+ attempt
+ }
+
+ override def setAttemptId(attemptId: Int): Unit = {
+ attempt = attemptId
+ }
+
+ override def getApplicationId: ApplicationId = {
+ applicationId
+ }
+
+ override def build(): Unit = {
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala
new file mode 100644
index 0000000000..bffa0e09be
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.api.records.ApplicationId
+
+/**
+ * Simple Testing Application Id; ID and cluster timestamp are set in constructor
+ * and cannot be updated.
+ * @param id app id
+ * @param clusterTimestamp timestamp
+ */
+private[spark] class StubApplicationId(id: Int, clusterTimestamp: Long) extends ApplicationId {
+ override def getId: Int = {
+ id
+ }
+
+ override def getClusterTimestamp: Long = {
+ clusterTimestamp
+ }
+
+ override def setId(id: Int): Unit = {}
+
+ override def setClusterTimestamp(clusterTimestamp: Long): Unit = {}
+
+ override def build(): Unit = {}
+}