diff options
Diffstat (limited to 'yarn')
7 files changed, 658 insertions, 164 deletions
diff --git a/yarn/pom.xml b/yarn/pom.xml index 15db54e4e7..f673769530 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -40,6 +40,12 @@ </dependency> <dependency> <groupId>org.apache.spark</groupId> + <artifactId>spark-network-yarn_${scala.binary.version}</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> <artifactId>spark-core_${scala.binary.version}</artifactId> <version>${project.version}</version> <type>test-jar</type> diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala new file mode 100644 index 0000000000..128e996b71 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.util.Properties +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.{BeforeAndAfterAll, Matchers} + +import org.apache.spark._ +import org.apache.spark.util.Utils + +abstract class BaseYarnClusterSuite + extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { + + // log4j configuration for the YARN containers, so that their output is collected + // by YARN instead of trying to overwrite unit-tests.log. + protected val LOG4J_CONF = """ + |log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin + + private var yarnCluster: MiniYARNCluster = _ + protected var tempDir: File = _ + private var fakeSparkJar: File = _ + private var hadoopConfDir: File = _ + private var logConfDir: File = _ + + + def yarnConfig: YarnConfiguration + + override def beforeAll() { + super.beforeAll() + + tempDir = Utils.createTempDir() + logConfDir = new File(tempDir, "log4j") + logConfDir.mkdir() + System.setProperty("SPARK_YARN_MODE", "true") + + val logConfFile = new File(logConfDir, "log4j.properties") + Files.write(LOG4J_CONF, logConfFile, UTF_8) + + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) + yarnCluster.init(yarnConfig) + yarnCluster.start() + + // There's a race in MiniYARNCluster in which start() may return before the RM has updated + // its address in the configuration. You can see this in the logs by noticing that when + // MiniYARNCluster prints the address, it still has port "0" assigned, although later the + // test works sometimes: + // + // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 + // + // That log message prints the contents of the RM_ADDRESS config variable. If you check it + // later on, it looks something like this: + // + // INFO YarnClusterSuite: RM address in configuration is blah:42631 + // + // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't + // done so in a timely manner (defined to be 10 seconds). + val config = yarnCluster.getConfig() + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) + while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { + if (System.currentTimeMillis() > deadline) { + throw new IllegalStateException("Timed out waiting for RM to come up.") + } + logDebug("RM address still not set in configuration, waiting...") + TimeUnit.MILLISECONDS.sleep(100) + } + + logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) + } + + override def afterAll() { + yarnCluster.stop() + System.clearProperty("SPARK_YARN_MODE") + super.afterAll() + } + + protected def runSpark( + clientMode: Boolean, + klass: String, + appArgs: Seq[String] = Nil, + sparkArgs: Seq[String] = Nil, + extraClassPath: Seq[String] = Nil, + extraJars: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): Unit = { + val master = if (clientMode) "yarn-client" else "yarn-cluster" + val props = new Properties() + + props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + + val childClasspath = logConfDir.getAbsolutePath() + + File.pathSeparator + + sys.props("java.class.path") + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator) + props.setProperty("spark.driver.extraClassPath", childClasspath) + props.setProperty("spark.executor.extraClassPath", childClasspath) + + // SPARK-4267: make sure java options are propagated correctly. + props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") + props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + + yarnCluster.getConfig().foreach { e => + props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) + } + + sys.props.foreach { case (k, v) => + if (k.startsWith("spark.")) { + props.setProperty(k, v) + } + } + + extraConf.foreach { case (k, v) => props.setProperty(k, v) } + + val propsFile = File.createTempFile("spark", ".properties", tempDir) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + props.store(writer, "Spark properties.") + writer.close() + + val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil + val mainArgs = + if (klass.endsWith(".py")) { + Seq(klass) + } else { + Seq("--class", klass, fakeSparkJar.getAbsolutePath()) + } + val argv = + Seq( + new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), + "--master", master, + "--num-executors", "1", + "--properties-file", propsFile.getAbsolutePath()) ++ + extraJarArgs ++ + sparkArgs ++ + mainArgs ++ + appArgs + + Utils.executeAndGetOutput(argv, + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. + */ + protected def checkResult(result: File): Unit = { + checkResult(result, "success") + } + + protected def checkResult(result: File, expected: String): Unit = { + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index eb6e1fd370..128350b648 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -17,25 +17,20 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.io.File import java.net.URL -import java.util.Properties -import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.collection.JavaConversions._ import com.google.common.base.Charsets.UTF_8 -import com.google.common.io.ByteStreams -import com.google.common.io.Files +import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.Matchers import org.apache.spark._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, - SparkListenerExecutorAdded} import org.apache.spark.util.Utils /** @@ -43,17 +38,9 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { - - // log4j configuration for the YARN containers, so that their output is collected - // by YARN instead of trying to overwrite unit-tests.log. - private val LOG4J_CONF = """ - |log4j.rootCategory=DEBUG, console - |log4j.appender.console=org.apache.log4j.ConsoleAppender - |log4j.appender.console.target=System.err - |log4j.appender.console.layout=org.apache.log4j.PatternLayout - |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - """.stripMargin +class YarnClusterSuite extends BaseYarnClusterSuite { + + override def yarnConfig: YarnConfiguration = new YarnConfiguration() private val TEST_PYFILE = """ |import mod1, mod2 @@ -82,65 +69,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | return 42 """.stripMargin - private var yarnCluster: MiniYARNCluster = _ - private var tempDir: File = _ - private var fakeSparkJar: File = _ - private var hadoopConfDir: File = _ - private var logConfDir: File = _ - - override def beforeAll() { - super.beforeAll() - - tempDir = Utils.createTempDir() - logConfDir = new File(tempDir, "log4j") - logConfDir.mkdir() - System.setProperty("SPARK_YARN_MODE", "true") - - val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, UTF_8) - - yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) - yarnCluster.init(new YarnConfiguration()) - yarnCluster.start() - - // There's a race in MiniYARNCluster in which start() may return before the RM has updated - // its address in the configuration. You can see this in the logs by noticing that when - // MiniYARNCluster prints the address, it still has port "0" assigned, although later the - // test works sometimes: - // - // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 - // - // That log message prints the contents of the RM_ADDRESS config variable. If you check it - // later on, it looks something like this: - // - // INFO YarnClusterSuite: RM address in configuration is blah:42631 - // - // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't - // done so in a timely manner (defined to be 10 seconds). - val config = yarnCluster.getConfig() - val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) - while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { - if (System.currentTimeMillis() > deadline) { - throw new IllegalStateException("Timed out waiting for RM to come up.") - } - logDebug("RM address still not set in configuration, waiting...") - TimeUnit.MILLISECONDS.sleep(100) - } - - logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - - fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) - assert(hadoopConfDir.mkdir()) - File.createTempFile("token", ".txt", hadoopConfDir) - } - - override def afterAll() { - yarnCluster.stop() - System.clearProperty("SPARK_YARN_MODE") - super.afterAll() - } - test("run Spark in yarn-client mode") { testBasicYarnApp(true) } @@ -174,7 +102,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher } private def testBasicYarnApp(clientMode: Boolean): Unit = { - var result = File.createTempFile("result", null, tempDir) + val result = File.createTempFile("result", null, tempDir) runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), appArgs = Seq(result.getAbsolutePath())) checkResult(result) @@ -224,89 +152,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher checkResult(executorResult, "OVERRIDDEN") } - private def runSpark( - clientMode: Boolean, - klass: String, - appArgs: Seq[String] = Nil, - sparkArgs: Seq[String] = Nil, - extraClassPath: Seq[String] = Nil, - extraJars: Seq[String] = Nil, - extraConf: Map[String, String] = Map()): Unit = { - val master = if (clientMode) "yarn-client" else "yarn-cluster" - val props = new Properties() - - props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) - - val childClasspath = logConfDir.getAbsolutePath() + - File.pathSeparator + - sys.props("java.class.path") + - File.pathSeparator + - extraClassPath.mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", childClasspath) - props.setProperty("spark.executor.extraClassPath", childClasspath) - - // SPARK-4267: make sure java options are propagated correctly. - props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") - props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - - yarnCluster.getConfig().foreach { e => - props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) - } - - sys.props.foreach { case (k, v) => - if (k.startsWith("spark.")) { - props.setProperty(k, v) - } - } - - extraConf.foreach { case (k, v) => props.setProperty(k, v) } - - val propsFile = File.createTempFile("spark", ".properties", tempDir) - val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) - props.store(writer, "Spark properties.") - writer.close() - - val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil - val mainArgs = - if (klass.endsWith(".py")) { - Seq(klass) - } else { - Seq("--class", klass, fakeSparkJar.getAbsolutePath()) - } - val argv = - Seq( - new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), - "--master", master, - "--num-executors", "1", - "--properties-file", propsFile.getAbsolutePath()) ++ - extraJarArgs ++ - sparkArgs ++ - mainArgs ++ - appArgs - - Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) - } - - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. - */ - private def checkResult(result: File): Unit = { - checkResult(result, "success") - } - - private def checkResult(result: File, expected: String): Unit = { - var resultString = Files.toString(result, UTF_8) - resultString should be (expected) - } - - private def mainClassName(klass: Class[_]): String = { - klass.getName().stripSuffix("$") - } - } private[spark] class SaveExecutorInfo extends SparkListener { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala new file mode 100644 index 0000000000..5e8238822b --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -0,0 +1,109 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.deploy.yarn + +import java.io.File + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.scalatest.Matchers + +import org.apache.spark._ +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} + +/** + * Integration test for the external shuffle service with a yarn mini-cluster + */ +class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { + + override def yarnConfig: YarnConfiguration = { + val yarnConfig = new YarnConfiguration() + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + yarnConfig.set("spark.shuffle.service.port", "0") + yarnConfig + } + + test("external shuffle service") { + val shuffleServicePort = YarnTestAccessor.getShuffleServicePort + val shuffleService = YarnTestAccessor.getShuffleServiceInstance + + val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService) + + logInfo("Shuffle service port = " + shuffleServicePort) + val result = File.createTempFile("result", null, tempDir) + runSpark( + false, + mainClassName(YarnExternalShuffleDriver.getClass), + appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), + extraConf = Map( + "spark.shuffle.service.enabled" -> "true", + "spark.shuffle.service.port" -> shuffleServicePort.toString + ) + ) + checkResult(result) + assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) + } +} + +private object YarnExternalShuffleDriver extends Logging with Matchers { + + val WAIT_TIMEOUT_MILLIS = 10000 + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + // scalastyle:off println + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: ExternalShuffleDriver [result file] [registed exec file] + """.stripMargin) + // scalastyle:on println + System.exit(1) + } + + val sc = new SparkContext(new SparkConf() + .setAppName("External Shuffle Test")) + val conf = sc.getConf + val status = new File(args(0)) + val registeredExecFile = new File(args(1)) + logInfo("shuffle service executor file = " + registeredExecFile) + var result = "failure" + val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup") + try { + val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }. + collect().toSet + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet) + result = "success" + // only one process can open a leveldb file at a time, so we copy the files + FileUtils.copyDirectory(registeredExecFile, execStateCopy) + assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty) + } finally { + sc.stop() + FileUtils.deleteDirectory(execStateCopy) + Files.write(result, status, UTF_8) + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala new file mode 100644 index 0000000000..aa46ec5100 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.shuffle + +import java.io.{IOException, File} +import java.util.concurrent.ConcurrentMap + +import com.google.common.annotations.VisibleForTesting +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.fusesource.leveldbjni.JniDBFactory +import org.iq80.leveldb.{DB, Options} + +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +/** + * just a cheat to get package-visible members in tests + */ +object ShuffleTestAccessor { + + def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = { + handler.blockManager + } + + def getExecutorInfo( + appId: ApplicationId, + execId: String, + resolver: ExternalShuffleBlockResolver + ): Option[ExecutorShuffleInfo] = { + val id = new AppExecId(appId.toString, execId) + Option(resolver.executors.get(id)) + } + + def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = { + resolver.registeredExecutorFile + } + + def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = { + resolver.db + } + + def reloadRegisteredExecutors( + file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + val options: Options = new Options + options.createIfMissing(true) + val factory = new JniDBFactory + val db = factory.open(file, options) + val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + db.close() + result + } + + def reloadRegisteredExecutors( + db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala new file mode 100644 index 0000000000..2f22cbdbea --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.yarn + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.annotation.tailrec + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext} +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration + + override def beforeEach(): Unit = { + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + + yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir => + val d = new File(dir) + if (d.exists()) { + FileUtils.deleteDirectory(d) + } + FileUtils.forceMkdir(d) + logInfo(s"creating yarn.nodemanager.local-dirs: $d") + } + } + + var s1: YarnShuffleService = null + var s2: YarnShuffleService = null + var s3: YarnShuffleService = null + + override def afterEach(): Unit = { + if (s1 != null) { + s1.stop() + s1 = null + } + if (s2 != null) { + s2.stop() + s2 = null + } + if (s3 != null) { + s3.stop() + s3 = null + } + } + + test("executor state kept across NM restart") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should + be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should + be (Some(shuffleInfo2)) + + if (!execStateFile.exists()) { + @tailrec def findExistingParent(file: File): File = { + if (file == null) file + else if (file.exists()) file + else findExistingParent(file.getParentFile()) + } + val existingParent = findExistingParent(execStateFile) + assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent") + } + assert(execStateFile.exists(), s"$execStateFile did not exist") + + // now we pretend the shuffle service goes down, and comes back up + s1.stop() + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped + // during the restart + s2.initializeApplication(app1Data) + s2.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None) + + // Act like the NM restarts one more time + s2.stop() + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + // app1 is still running + s3.initializeApplication(app1Data) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None) + s3.stop() + } + + test("removed applications should not be in registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + + val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + + s1.stopApplication(new ApplicationTerminationContext(app1Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + s1.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty + } + + test("shuffle service should be robust to corrupt registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + + val execStateFile = s1.registeredExecutorFile + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + + // now we pretend the shuffle service goes down, and comes back up. But we'll also + // make a corrupt registeredExecutor File + s1.stop() + + execStateFile.listFiles().foreach{_.delete()} + + val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT")) + out.writeInt(42) + out.close() + + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // we re-initialize app1, but since the file was corrupt there is nothing we can do about it ... + s2.initializeApplication(app1Data) + // however, when we initialize a totally new app2, everything is still happy + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s2.initializeApplication(app2Data) + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2)) + s2.stop() + + // another stop & restart should be fine though (eg., we recover from previous corruption) + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + s3.initializeApplication(app2Data) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2)) + s3.stop() + + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala new file mode 100644 index 0000000000..db322cd18e --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.yarn + +import java.io.File + +/** + * just a cheat to get package-visible members in tests + */ +object YarnTestAccessor { + def getShuffleServicePort: Int = { + YarnShuffleService.boundPort + } + + def getShuffleServiceInstance: YarnShuffleService = { + YarnShuffleService.instance + } + + def getRegisteredExecutorFile(service: YarnShuffleService): File = { + service.registeredExecutorFile + } + +} |