Diffstat (limited to 'yarn')
-rw-r--r--  yarn/pom.xml                                                                        6
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala       193
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala           173
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala 109
-rw-r--r--  yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala     71
-rw-r--r--  yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala   233
-rw-r--r--  yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala           37
7 files changed, 658 insertions, 164 deletions
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 15db54e4e7..f673769530 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -40,6 +40,12 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-yarn_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
new file mode 100644
index 0000000000..128e996b71
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConversions._
+
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.Files
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.server.MiniYARNCluster
+import org.scalatest.{BeforeAndAfterAll, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.util.Utils
+
+abstract class BaseYarnClusterSuite
+ extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging {
+
+ // log4j configuration for the YARN containers, so that their output is collected
+ // by YARN instead of trying to overwrite unit-tests.log.
+ protected val LOG4J_CONF = """
+ |log4j.rootCategory=DEBUG, console
+ |log4j.appender.console=org.apache.log4j.ConsoleAppender
+ |log4j.appender.console.target=System.err
+ |log4j.appender.console.layout=org.apache.log4j.PatternLayout
+ |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+ """.stripMargin
+
+ private var yarnCluster: MiniYARNCluster = _
+ protected var tempDir: File = _
+ private var fakeSparkJar: File = _
+ private var hadoopConfDir: File = _
+ private var logConfDir: File = _
+
+
+ def yarnConfig: YarnConfiguration
+
+ override def beforeAll() {
+ super.beforeAll()
+
+ tempDir = Utils.createTempDir()
+ logConfDir = new File(tempDir, "log4j")
+ logConfDir.mkdir()
+ System.setProperty("SPARK_YARN_MODE", "true")
+
+ val logConfFile = new File(logConfDir, "log4j.properties")
+ Files.write(LOG4J_CONF, logConfFile, UTF_8)
+
+ yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
+ yarnCluster.init(yarnConfig)
+ yarnCluster.start()
+
+ // There's a race in MiniYARNCluster in which start() may return before the RM has updated
+ // its address in the configuration. You can see this in the logs by noticing that when
+ // MiniYARNCluster prints the address, it still has port "0" assigned, although later the
+ // test works sometimes:
+ //
+ // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0
+ //
+ // That log message prints the contents of the RM_ADDRESS config variable. If you check it
+ // later on, it looks something like this:
+ //
+ // INFO YarnClusterSuite: RM address in configuration is blah:42631
+ //
+ // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't
+ // done so in a timely manner (defined to be 10 seconds).
+ val config = yarnCluster.getConfig()
+ val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10)
+ while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") {
+ if (System.currentTimeMillis() > deadline) {
+ throw new IllegalStateException("Timed out waiting for RM to come up.")
+ }
+ logDebug("RM address still not set in configuration, waiting...")
+ TimeUnit.MILLISECONDS.sleep(100)
+ }
+
+ logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
+
+ fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
+ hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
+ assert(hadoopConfDir.mkdir())
+ File.createTempFile("token", ".txt", hadoopConfDir)
+ }
+
+ override def afterAll() {
+ yarnCluster.stop()
+ System.clearProperty("SPARK_YARN_MODE")
+ super.afterAll()
+ }
+
+ protected def runSpark(
+ clientMode: Boolean,
+ klass: String,
+ appArgs: Seq[String] = Nil,
+ sparkArgs: Seq[String] = Nil,
+ extraClassPath: Seq[String] = Nil,
+ extraJars: Seq[String] = Nil,
+ extraConf: Map[String, String] = Map()): Unit = {
+ val master = if (clientMode) "yarn-client" else "yarn-cluster"
+ val props = new Properties()
+
+ props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath())
+
+ val childClasspath = logConfDir.getAbsolutePath() +
+ File.pathSeparator +
+ sys.props("java.class.path") +
+ File.pathSeparator +
+ extraClassPath.mkString(File.pathSeparator)
+ props.setProperty("spark.driver.extraClassPath", childClasspath)
+ props.setProperty("spark.executor.extraClassPath", childClasspath)
+
+ // SPARK-4267: make sure java options are propagated correctly.
+ props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
+ props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
+
+ yarnCluster.getConfig().foreach { e =>
+ props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
+ }
+
+ sys.props.foreach { case (k, v) =>
+ if (k.startsWith("spark.")) {
+ props.setProperty(k, v)
+ }
+ }
+
+ extraConf.foreach { case (k, v) => props.setProperty(k, v) }
+
+ val propsFile = File.createTempFile("spark", ".properties", tempDir)
+ val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8)
+ props.store(writer, "Spark properties.")
+ writer.close()
+
+ val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil
+ val mainArgs =
+ if (klass.endsWith(".py")) {
+ Seq(klass)
+ } else {
+ Seq("--class", klass, fakeSparkJar.getAbsolutePath())
+ }
+ val argv =
+ Seq(
+ new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(),
+ "--master", master,
+ "--num-executors", "1",
+ "--properties-file", propsFile.getAbsolutePath()) ++
+ extraJarArgs ++
+ sparkArgs ++
+ mainArgs ++
+ appArgs
+
+ Utils.executeAndGetOutput(argv,
+ extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()))
+ }
+
+ /**
+ * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
+ * any sort of error when the job process finishes successfully, but the job itself fails. So
+ * the tests enforce that something is written to a file after everything is ok to indicate
+ * that the job succeeded.
+ */
+ protected def checkResult(result: File): Unit = {
+ checkResult(result, "success")
+ }
+
+ protected def checkResult(result: File, expected: String): Unit = {
+ val resultString = Files.toString(result, UTF_8)
+ resultString should be (expected)
+ }
+
+ protected def mainClassName(klass: Class[_]): String = {
+ klass.getName().stripSuffix("$")
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index eb6e1fd370..128350b648 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -17,25 +17,20 @@
package org.apache.spark.deploy.yarn
-import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.io.File
import java.net.URL
-import java.util.Properties
-import java.util.concurrent.TimeUnit
-import scala.collection.JavaConversions._
import scala.collection.mutable
+import scala.collection.JavaConversions._
import com.google.common.base.Charsets.UTF_8
-import com.google.common.io.ByteStreams
-import com.google.common.io.Files
+import com.google.common.io.{ByteStreams, Files}
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.server.MiniYARNCluster
-import org.scalatest.{BeforeAndAfterAll, Matchers}
+import org.scalatest.Matchers
import org.apache.spark._
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded}
import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
- SparkListenerExecutorAdded}
import org.apache.spark.util.Utils
/**
@@ -43,17 +38,9 @@ import org.apache.spark.util.Utils
* applications, and require the Spark assembly to be built before they can be successfully
* run.
*/
-class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging {
-
- // log4j configuration for the YARN containers, so that their output is collected
- // by YARN instead of trying to overwrite unit-tests.log.
- private val LOG4J_CONF = """
- |log4j.rootCategory=DEBUG, console
- |log4j.appender.console=org.apache.log4j.ConsoleAppender
- |log4j.appender.console.target=System.err
- |log4j.appender.console.layout=org.apache.log4j.PatternLayout
- |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
- """.stripMargin
+class YarnClusterSuite extends BaseYarnClusterSuite {
+
+ override def yarnConfig: YarnConfiguration = new YarnConfiguration()
private val TEST_PYFILE = """
|import mod1, mod2
@@ -82,65 +69,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
| return 42
""".stripMargin
- private var yarnCluster: MiniYARNCluster = _
- private var tempDir: File = _
- private var fakeSparkJar: File = _
- private var hadoopConfDir: File = _
- private var logConfDir: File = _
-
- override def beforeAll() {
- super.beforeAll()
-
- tempDir = Utils.createTempDir()
- logConfDir = new File(tempDir, "log4j")
- logConfDir.mkdir()
- System.setProperty("SPARK_YARN_MODE", "true")
-
- val logConfFile = new File(logConfDir, "log4j.properties")
- Files.write(LOG4J_CONF, logConfFile, UTF_8)
-
- yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
- yarnCluster.init(new YarnConfiguration())
- yarnCluster.start()
-
- // There's a race in MiniYARNCluster in which start() may return before the RM has updated
- // its address in the configuration. You can see this in the logs by noticing that when
- // MiniYARNCluster prints the address, it still has port "0" assigned, although later the
- // test works sometimes:
- //
- // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0
- //
- // That log message prints the contents of the RM_ADDRESS config variable. If you check it
- // later on, it looks something like this:
- //
- // INFO YarnClusterSuite: RM address in configuration is blah:42631
- //
- // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't
- // done so in a timely manner (defined to be 10 seconds).
- val config = yarnCluster.getConfig()
- val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10)
- while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") {
- if (System.currentTimeMillis() > deadline) {
- throw new IllegalStateException("Timed out waiting for RM to come up.")
- }
- logDebug("RM address still not set in configuration, waiting...")
- TimeUnit.MILLISECONDS.sleep(100)
- }
-
- logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
-
- fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
- hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
- assert(hadoopConfDir.mkdir())
- File.createTempFile("token", ".txt", hadoopConfDir)
- }
-
- override def afterAll() {
- yarnCluster.stop()
- System.clearProperty("SPARK_YARN_MODE")
- super.afterAll()
- }
-
test("run Spark in yarn-client mode") {
testBasicYarnApp(true)
}
@@ -174,7 +102,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
}
private def testBasicYarnApp(clientMode: Boolean): Unit = {
- var result = File.createTempFile("result", null, tempDir)
+ val result = File.createTempFile("result", null, tempDir)
runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
appArgs = Seq(result.getAbsolutePath()))
checkResult(result)
@@ -224,89 +152,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
checkResult(executorResult, "OVERRIDDEN")
}
- private def runSpark(
- clientMode: Boolean,
- klass: String,
- appArgs: Seq[String] = Nil,
- sparkArgs: Seq[String] = Nil,
- extraClassPath: Seq[String] = Nil,
- extraJars: Seq[String] = Nil,
- extraConf: Map[String, String] = Map()): Unit = {
- val master = if (clientMode) "yarn-client" else "yarn-cluster"
- val props = new Properties()
-
- props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath())
-
- val childClasspath = logConfDir.getAbsolutePath() +
- File.pathSeparator +
- sys.props("java.class.path") +
- File.pathSeparator +
- extraClassPath.mkString(File.pathSeparator)
- props.setProperty("spark.driver.extraClassPath", childClasspath)
- props.setProperty("spark.executor.extraClassPath", childClasspath)
-
- // SPARK-4267: make sure java options are propagated correctly.
- props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
- props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
-
- yarnCluster.getConfig().foreach { e =>
- props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
- }
-
- sys.props.foreach { case (k, v) =>
- if (k.startsWith("spark.")) {
- props.setProperty(k, v)
- }
- }
-
- extraConf.foreach { case (k, v) => props.setProperty(k, v) }
-
- val propsFile = File.createTempFile("spark", ".properties", tempDir)
- val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8)
- props.store(writer, "Spark properties.")
- writer.close()
-
- val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil
- val mainArgs =
- if (klass.endsWith(".py")) {
- Seq(klass)
- } else {
- Seq("--class", klass, fakeSparkJar.getAbsolutePath())
- }
- val argv =
- Seq(
- new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(),
- "--master", master,
- "--num-executors", "1",
- "--properties-file", propsFile.getAbsolutePath()) ++
- extraJarArgs ++
- sparkArgs ++
- mainArgs ++
- appArgs
-
- Utils.executeAndGetOutput(argv,
- extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()))
- }
-
- /**
- * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
- * any sort of error when the job process finishes successfully, but the job itself fails. So
- * the tests enforce that something is written to a file after everything is ok to indicate
- * that the job succeeded.
- */
- private def checkResult(result: File): Unit = {
- checkResult(result, "success")
- }
-
- private def checkResult(result: File, expected: String): Unit = {
- var resultString = Files.toString(result, UTF_8)
- resultString should be (expected)
- }
-
- private def mainClassName(klass: Class[_]): String = {
- klass.getName().stripSuffix("$")
- }
-
}
private[spark] class SaveExecutorInfo extends SparkListener {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
new file mode 100644
index 0000000000..5e8238822b
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
@@ -0,0 +1,109 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.Files
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.Matchers
+
+import org.apache.spark._
+import org.apache.spark.network.shuffle.ShuffleTestAccessor
+import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor}
+
+/**
+ * Integration test for the external shuffle service with a yarn mini-cluster
+ */
+class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite {
+
+ override def yarnConfig: YarnConfiguration = {
+ val yarnConfig = new YarnConfiguration()
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle")
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"),
+ classOf[YarnShuffleService].getCanonicalName)
+ yarnConfig.set("spark.shuffle.service.port", "0")
+ yarnConfig
+ }
+
+ test("external shuffle service") {
+ val shuffleServicePort = YarnTestAccessor.getShuffleServicePort
+ val shuffleService = YarnTestAccessor.getShuffleServiceInstance
+
+ val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService)
+
+ logInfo("Shuffle service port = " + shuffleServicePort)
+ val result = File.createTempFile("result", null, tempDir)
+ runSpark(
+ false,
+ mainClassName(YarnExternalShuffleDriver.getClass),
+ appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath),
+ extraConf = Map(
+ "spark.shuffle.service.enabled" -> "true",
+ "spark.shuffle.service.port" -> shuffleServicePort.toString
+ )
+ )
+ checkResult(result)
+ assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists())
+ }
+}
+
+private object YarnExternalShuffleDriver extends Logging with Matchers {
+
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 2) {
+ // scalastyle:off println
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+          |Usage: ExternalShuffleDriver [result file] [registered exec file]
+ """.stripMargin)
+ // scalastyle:on println
+ System.exit(1)
+ }
+
+ val sc = new SparkContext(new SparkConf()
+ .setAppName("External Shuffle Test"))
+ val conf = sc.getConf
+ val status = new File(args(0))
+ val registeredExecFile = new File(args(1))
+ logInfo("shuffle service executor file = " + registeredExecFile)
+ var result = "failure"
+ val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup")
+ try {
+ val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }.
+ collect().toSet
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet)
+ result = "success"
+ // only one process can open a leveldb file at a time, so we copy the files
+ FileUtils.copyDirectory(registeredExecFile, execStateCopy)
+ assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty)
+ } finally {
+ sc.stop()
+ FileUtils.deleteDirectory(execStateCopy)
+ Files.write(result, status, UTF_8)
+ }
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
new file mode 100644
index 0000000000..aa46ec5100
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.shuffle
+
+import java.io.{IOException, File}
+import java.util.concurrent.ConcurrentMap
+
+import com.google.common.annotations.VisibleForTesting
+import org.apache.hadoop.yarn.api.records.ApplicationId
+import org.fusesource.leveldbjni.JniDBFactory
+import org.iq80.leveldb.{DB, Options}
+
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+
+/**
+ * just a cheat to get package-visible members in tests
+ */
+object ShuffleTestAccessor {
+
+ def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = {
+ handler.blockManager
+ }
+
+ def getExecutorInfo(
+ appId: ApplicationId,
+ execId: String,
+ resolver: ExternalShuffleBlockResolver
+ ): Option[ExecutorShuffleInfo] = {
+ val id = new AppExecId(appId.toString, execId)
+ Option(resolver.executors.get(id))
+ }
+
+ def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = {
+ resolver.registeredExecutorFile
+ }
+
+ def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = {
+ resolver.db
+ }
+
+ def reloadRegisteredExecutors(
+ file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = {
+ val options: Options = new Options
+ options.createIfMissing(true)
+ val factory = new JniDBFactory
+ val db = factory.open(file, options)
+ val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db)
+ db.close()
+ result
+ }
+
+ def reloadRegisteredExecutors(
+ db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = {
+ ExternalShuffleBlockResolver.reloadRegisteredExecutors(db)
+ }
+}
diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
new file mode 100644
index 0000000000..2f22cbdbea
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.yarn
+
+import java.io.{DataOutputStream, File, FileOutputStream}
+
+import scala.annotation.tailrec
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.yarn.api.records.ApplicationId
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext}
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.network.shuffle.ShuffleTestAccessor
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+
+class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
+ private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration
+
+ override def beforeEach(): Unit = {
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle")
+ yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"),
+ classOf[YarnShuffleService].getCanonicalName)
+
+ yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir =>
+ val d = new File(dir)
+ if (d.exists()) {
+ FileUtils.deleteDirectory(d)
+ }
+ FileUtils.forceMkdir(d)
+ logInfo(s"creating yarn.nodemanager.local-dirs: $d")
+ }
+ }
+
+ var s1: YarnShuffleService = null
+ var s2: YarnShuffleService = null
+ var s3: YarnShuffleService = null
+
+ override def afterEach(): Unit = {
+ if (s1 != null) {
+ s1.stop()
+ s1 = null
+ }
+ if (s2 != null) {
+ s2.stop()
+ s2 = null
+ }
+ if (s3 != null) {
+ s3.stop()
+ s3 = null
+ }
+ }
+
+ test("executor state kept across NM restart") {
+ s1 = new YarnShuffleService
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app1Id, null)
+ s1.initializeApplication(app1Data)
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app2Id, null)
+ s1.initializeApplication(app2Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ execStateFile should not be (null)
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort")
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash")
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+ blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should
+ be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should
+ be (Some(shuffleInfo2))
+
+ if (!execStateFile.exists()) {
+ @tailrec def findExistingParent(file: File): File = {
+ if (file == null) file
+ else if (file.exists()) file
+ else findExistingParent(file.getParentFile())
+ }
+ val existingParent = findExistingParent(execStateFile)
+ assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent")
+ }
+ assert(execStateFile.exists(), s"$execStateFile did not exist")
+
+ // now we pretend the shuffle service goes down, and comes back up
+ s1.stop()
+ s2 = new YarnShuffleService
+ s2.init(yarnConfig)
+ s2.registeredExecutorFile should be (execStateFile)
+
+ val handler2 = s2.blockHandler
+ val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2)
+
+ // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped
+ // during the restart
+ s2.initializeApplication(app1Data)
+ s2.stopApplication(new ApplicationTerminationContext(app2Id))
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None)
+
+ // Act like the NM restarts one more time
+ s2.stop()
+ s3 = new YarnShuffleService
+ s3.init(yarnConfig)
+ s3.registeredExecutorFile should be (execStateFile)
+
+ val handler3 = s3.blockHandler
+ val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3)
+
+ // app1 is still running
+ s3.initializeApplication(app1Data)
+ ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1))
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None)
+ s3.stop()
+ }
+
+ test("removed applications should not be in registered executor file") {
+ s1 = new YarnShuffleService
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app1Id, null)
+ s1.initializeApplication(app1Data)
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app2Id, null)
+ s1.initializeApplication(app2Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ execStateFile should not be (null)
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort")
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash")
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+ blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+
+ val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver)
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty
+
+ s1.stopApplication(new ApplicationTerminationContext(app1Id))
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty
+ s1.stopApplication(new ApplicationTerminationContext(app2Id))
+ ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty
+ }
+
+ test("shuffle service should be robust to corrupt registered executor file") {
+ s1 = new YarnShuffleService
+ s1.init(yarnConfig)
+ val app1Id = ApplicationId.newInstance(0, 1)
+ val app1Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app1Id, null)
+ s1.initializeApplication(app1Data)
+
+ val execStateFile = s1.registeredExecutorFile
+ val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort")
+
+ val blockHandler = s1.blockHandler
+ val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler)
+ ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile)
+
+ blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1)
+
+ // now we pretend the shuffle service goes down, and comes back up. But we'll also
+ // make a corrupt registeredExecutor File
+ s1.stop()
+
+ execStateFile.listFiles().foreach{_.delete()}
+
+ val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT"))
+ out.writeInt(42)
+ out.close()
+
+ s2 = new YarnShuffleService
+ s2.init(yarnConfig)
+ s2.registeredExecutorFile should be (execStateFile)
+
+ val handler2 = s2.blockHandler
+ val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2)
+
+ // we re-initialize app1, but since the file was corrupt there is nothing we can do about it ...
+ s2.initializeApplication(app1Data)
+ // however, when we initialize a totally new app2, everything is still happy
+ val app2Id = ApplicationId.newInstance(0, 2)
+ val app2Data: ApplicationInitializationContext =
+ new ApplicationInitializationContext("user", app2Id, null)
+ s2.initializeApplication(app2Data)
+ val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash")
+ resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2)
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2))
+ s2.stop()
+
+ // another stop & restart should be fine though (eg., we recover from previous corruption)
+ s3 = new YarnShuffleService
+ s3.init(yarnConfig)
+ s3.registeredExecutorFile should be (execStateFile)
+ val handler3 = s3.blockHandler
+ val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3)
+
+ s3.initializeApplication(app2Data)
+ ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2))
+ s3.stop()
+
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala
new file mode 100644
index 0000000000..db322cd18e
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.yarn
+
+import java.io.File
+
+/**
+ * just a cheat to get package-visible members in tests
+ */
+object YarnTestAccessor {
+ def getShuffleServicePort: Int = {
+ YarnShuffleService.boundPort
+ }
+
+ def getShuffleServiceInstance: YarnShuffleService = {
+ YarnShuffleService.instance
+ }
+
+ def getRegisteredExecutorFile(service: YarnShuffleService): File = {
+ service.registeredExecutorFile
+ }
+
+}