-rw-r--r--  core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala | 119
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala | 19
-rw-r--r--  core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java | 39
-rw-r--r--  core/src/test/resources/log4j.properties | 11
-rw-r--r--  core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala | 81
-rw-r--r--  launcher/pom.xml | 5
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java | 38
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java | 159
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java | 110
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java | 93
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java | 341
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java | 40
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java | 78
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java | 126
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java | 106
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java | 22
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/package-info.java | 38
-rw-r--r--  launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java | 32
-rw-r--r--  launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java | 188
-rw-r--r--  launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java | 4
-rw-r--r--  launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java | 2
-rw-r--r--  launcher/src/test/resources/log4j.properties | 13
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 43
-rw-r--r--  yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala | 10
-rw-r--r--  yarn/src/test/resources/log4j.properties | 7
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala | 127
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 76
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala | 4
29 files changed, 1820 insertions(+), 146 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala
new file mode 100644
index 0000000000..3ea984c501
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher
+
+import java.net.{InetAddress, Socket}
+
+import org.apache.spark.SPARK_VERSION
+import org.apache.spark.launcher.LauncherProtocol._
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A class that can be used to talk to a launcher server. Users should extend this class to
+ * provide implementation for the abstract methods.
+ *
+ * See `LauncherServer` for an explanation of how launcher communication works.
+ */
+private[spark] abstract class LauncherBackend {
+
+ private var clientThread: Thread = _
+ private var connection: BackendConnection = _
+ private var lastState: SparkAppHandle.State = _
+ @volatile private var _isConnected = false
+
+ def connect(): Unit = {
+ val port = sys.env.get(LauncherProtocol.ENV_LAUNCHER_PORT).map(_.toInt)
+ val secret = sys.env.get(LauncherProtocol.ENV_LAUNCHER_SECRET)
+ if (port != None && secret != None) {
+ val s = new Socket(InetAddress.getLoopbackAddress(), port.get)
+ connection = new BackendConnection(s)
+ connection.send(new Hello(secret.get, SPARK_VERSION))
+ clientThread = LauncherBackend.threadFactory.newThread(connection)
+ clientThread.start()
+ _isConnected = true
+ }
+ }
+
+ def close(): Unit = {
+ if (connection != null) {
+ try {
+ connection.close()
+ } finally {
+ if (clientThread != null) {
+ clientThread.join()
+ }
+ }
+ }
+ }
+
+ def setAppId(appId: String): Unit = {
+ if (connection != null) {
+ connection.send(new SetAppId(appId))
+ }
+ }
+
+ def setState(state: SparkAppHandle.State): Unit = {
+ if (connection != null && lastState != state) {
+ connection.send(new SetState(state))
+ lastState = state
+ }
+ }
+
+ /** Return whether the launcher handle is still connected to this backend. */
+ def isConnected(): Boolean = _isConnected
+
+ /**
+ * Implementations should provide this method, which should try to stop the application
+ * as gracefully as possible.
+ */
+ protected def onStopRequest(): Unit
+
+ /**
+ * Callback for when the launcher handle disconnects from this backend.
+ */
+ protected def onDisconnected() : Unit = { }
+
+
+ private class BackendConnection(s: Socket) extends LauncherConnection(s) {
+
+ override protected def handle(m: Message): Unit = m match {
+ case _: Stop =>
+ onStopRequest()
+
+ case _ =>
+ throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}")
+ }
+
+ override def close(): Unit = {
+ try {
+ super.close()
+ } finally {
+ onDisconnected()
+ _isConnected = false
+ }
+ }
+
+ }
+
+}
+
+private object LauncherBackend {
+
+ val threadFactory = ThreadUtils.namedThreadFactory("LauncherBackend")
+
+}
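For context, the Scala sketch below is not part of this patch; the ExampleBackend class and the literal app id are hypothetical. It shows the call sequence the scheduler backends changed in the following hunks use against LauncherBackend: connect() at startup, setAppId() and setState() as the application progresses, and setState(finalState) followed by close() on shutdown, with onStopRequest() handling a kill request coming from the launcher process.

    package org.apache.spark.scheduler

    import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}

    // Hypothetical backend; LauncherBackend is private[spark], so this only
    // compiles inside the org.apache.spark package tree.
    private[spark] class ExampleBackend {

      private val launcherBackend = new LauncherBackend() {
        // Invoked when the launcher process sends a Stop message.
        override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED)
      }

      def start(): Unit = {
        launcherBackend.connect()  // no-op unless the process was started by the launcher library
        launcherBackend.setAppId("app-20151009-0001")  // hypothetical id, normally from the cluster manager
        launcherBackend.setState(SparkAppHandle.State.RUNNING)
      }

      def stop(finalState: SparkAppHandle.State = SparkAppHandle.State.FINISHED): Unit = {
        try {
          launcherBackend.setState(finalState)
        } finally {
          launcherBackend.close()
        }
      }
    }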
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 27491ecf8b..2625c3e7ac 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rpc.RpcAddress
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
+import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
@@ -36,6 +37,9 @@ private[spark] class SparkDeploySchedulerBackend(
private var client: AppClient = null
private var stopping = false
+ private val launcherBackend = new LauncherBackend() {
+ override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED)
+ }
@volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _
@volatile private var appId: String = _
@@ -47,6 +51,7 @@ private[spark] class SparkDeploySchedulerBackend(
override def start() {
super.start()
+ launcherBackend.connect()
// The endpoint for executors to talk to us
val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName,
@@ -87,24 +92,20 @@ private[spark] class SparkDeploySchedulerBackend(
command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor)
client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
client.start()
+ launcherBackend.setState(SparkAppHandle.State.SUBMITTED)
waitForRegistration()
+ launcherBackend.setState(SparkAppHandle.State.RUNNING)
}
- override def stop() {
- stopping = true
- super.stop()
- client.stop()
-
- val callback = shutdownCallback
- if (callback != null) {
- callback(this)
- }
+ override def stop(): Unit = synchronized {
+ stop(SparkAppHandle.State.FINISHED)
}
override def connected(appId: String) {
logInfo("Connected to Spark cluster with app ID " + appId)
this.appId = appId
notifyContext()
+ launcherBackend.setAppId(appId)
}
override def disconnected() {
@@ -117,6 +118,7 @@ private[spark] class SparkDeploySchedulerBackend(
override def dead(reason: String) {
notifyContext()
if (!stopping) {
+ launcherBackend.setState(SparkAppHandle.State.KILLED)
logError("Application has been killed. Reason: " + reason)
try {
scheduler.error(reason)
@@ -188,4 +190,19 @@ private[spark] class SparkDeploySchedulerBackend(
registrationBarrier.release()
}
+ private def stop(finalState: SparkAppHandle.State): Unit = synchronized {
+ stopping = true
+
+ launcherBackend.setState(finalState)
+ launcherBackend.close()
+
+ super.stop()
+ client.stop()
+
+ val callback = shutdownCallback
+ if (callback != null) {
+ callback(this)
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 4d48fcfea4..c633d860ae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -24,6 +24,7 @@ import java.nio.ByteBuffer
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
+import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.ExecutorInfo
@@ -103,6 +104,9 @@ private[spark] class LocalBackend(
private var localEndpoint: RpcEndpointRef = null
private val userClassPath = getUserClasspath(conf)
private val listenerBus = scheduler.sc.listenerBus
+ private val launcherBackend = new LauncherBackend() {
+ override def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED)
+ }
/**
* Returns a list of URLs representing the user classpath.
@@ -114,6 +118,8 @@ private[spark] class LocalBackend(
userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL)
}
+ launcherBackend.connect()
+
override def start() {
val rpcEnv = SparkEnv.get.rpcEnv
val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores)
@@ -122,10 +128,12 @@ private[spark] class LocalBackend(
System.currentTimeMillis,
executorEndpoint.localExecutorId,
new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty)))
+ launcherBackend.setAppId(appId)
+ launcherBackend.setState(SparkAppHandle.State.RUNNING)
}
override def stop() {
- localEndpoint.ask(StopExecutor)
+ stop(SparkAppHandle.State.FINISHED)
}
override def reviveOffers() {
@@ -145,4 +153,13 @@ private[spark] class LocalBackend(
override def applicationId(): String = appId
+ private def stop(finalState: SparkAppHandle.State): Unit = {
+ localEndpoint.ask(StopExecutor)
+ try {
+ launcherBackend.setState(finalState)
+ } finally {
+ launcherBackend.close()
+ }
+ }
+
}
diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
index d0c26dd056..aa15e792e2 100644
--- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
+++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
@@ -27,6 +27,7 @@ import java.util.Map;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.slf4j.bridge.SLF4JBridgeHandler;
import static org.junit.Assert.*;
/**
@@ -34,7 +35,13 @@ import static org.junit.Assert.*;
*/
public class SparkLauncherSuite {
+ static {
+ SLF4JBridgeHandler.removeHandlersForRootLogger();
+ SLF4JBridgeHandler.install();
+ }
+
private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class);
+ private static final NamedThreadFactory TF = new NamedThreadFactory("SparkLauncherSuite-%d");
@Test
public void testSparkArgumentHandling() throws Exception {
@@ -94,14 +101,15 @@ public class SparkLauncherSuite {
.addSparkArg(opts.CONF,
String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS))
.setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS,
- "-Dfoo=bar -Dtest.name=-testChildProcLauncher")
+ "-Dfoo=bar -Dtest.appender=childproc")
.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path"))
.addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow")
.setMainClass(SparkLauncherTestApp.class.getName())
.addAppArgs("proc");
final Process app = launcher.launch();
- new Redirector("stdout", app.getInputStream()).start();
- new Redirector("stderr", app.getErrorStream()).start();
+
+ new OutputRedirector(app.getInputStream(), TF);
+ new OutputRedirector(app.getErrorStream(), TF);
assertEquals(0, app.waitFor());
}
@@ -116,29 +124,4 @@ public class SparkLauncherSuite {
}
- private static class Redirector extends Thread {
-
- private final InputStream in;
-
- Redirector(String name, InputStream in) {
- this.in = in;
- setName(name);
- setDaemon(true);
- }
-
- @Override
- public void run() {
- try {
- BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
- String line;
- while ((line = reader.readLine()) != null) {
- LOG.warn(line);
- }
- } catch (Exception e) {
- LOG.error("Error reading process output.", e);
- }
- }
-
- }
-
}
diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
index eb3b1999eb..a54d27de91 100644
--- a/core/src/test/resources/log4j.properties
+++ b/core/src/test/resources/log4j.properties
@@ -16,13 +16,22 @@
#
# Set everything to be logged to the file target/unit-tests.log
-log4j.rootCategory=INFO, file
+test.appender=file
+log4j.rootCategory=INFO, ${test.appender}
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=true
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+# Tests that launch java subprocesses can set the "test.appender" system property to
+# "console" to avoid having the child process's logs overwrite the unit test's
+# log file.
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%t: %m%n
+
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.spark-project.jetty=WARN
org.spark-project.jetty.LEVEL=WARN
diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala
new file mode 100644
index 0000000000..07e8869833
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.scalatest.Matchers
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark._
+import org.apache.spark.launcher._
+
+class LauncherBackendSuite extends SparkFunSuite with Matchers {
+
+ private val tests = Seq(
+ "local" -> "local",
+ "standalone/client" -> "local-cluster[1,1,1024]")
+
+ tests.foreach { case (name, master) =>
+ test(s"$name: launcher handle") {
+ testWithMaster(master)
+ }
+ }
+
+ private def testWithMaster(master: String): Unit = {
+ val env = new java.util.HashMap[String, String]()
+ env.put("SPARK_PRINT_LAUNCH_COMMAND", "1")
+ val handle = new SparkLauncher(env)
+ .setSparkHome(sys.props("spark.test.home"))
+ .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path"))
+ .setConf("spark.ui.enabled", "false")
+ .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, s"-Dtest.appender=console")
+ .setMaster(master)
+ .setAppResource("spark-internal")
+ .setMainClass(TestApp.getClass.getName().stripSuffix("$"))
+ .startApplication()
+
+ try {
+ eventually(timeout(10 seconds), interval(100 millis)) {
+ handle.getAppId() should not be (null)
+ }
+
+ handle.stop()
+
+ eventually(timeout(10 seconds), interval(100 millis)) {
+ handle.getState() should be (SparkAppHandle.State.KILLED)
+ }
+ } finally {
+ handle.kill()
+ }
+ }
+
+}
+
+object TestApp {
+
+ def main(args: Array[String]): Unit = {
+ new SparkContext(new SparkConf()).parallelize(Seq(1)).foreach { i =>
+ Thread.sleep(TimeUnit.SECONDS.toMillis(20))
+ }
+ }
+
+}
diff --git a/launcher/pom.xml b/launcher/pom.xml
index d595d74642..5739bfc169 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -49,6 +49,11 @@
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
+ <artifactId>jul-to-slf4j</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<scope>test</scope>
</dependency>
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 610e8bdaaa..cf3729b7fe 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -47,7 +47,7 @@ abstract class AbstractCommandBuilder {
String javaHome;
String mainClass;
String master;
- String propertiesFile;
+ protected String propertiesFile;
final List<String> appArgs;
final List<String> jars;
final List<String> files;
@@ -55,6 +55,10 @@ abstract class AbstractCommandBuilder {
final Map<String, String> childEnv;
final Map<String, String> conf;
+ // The merged configuration for the application. Cached to avoid having to read / parse
+ // properties files multiple times.
+ private Map<String, String> effectiveConfig;
+
public AbstractCommandBuilder() {
this.appArgs = new ArrayList<String>();
this.childEnv = new HashMap<String, String>();
@@ -257,12 +261,38 @@ abstract class AbstractCommandBuilder {
return path;
}
+ String getenv(String key) {
+ return firstNonEmpty(childEnv.get(key), System.getenv(key));
+ }
+
+ void setPropertiesFile(String path) {
+ effectiveConfig = null;
+ this.propertiesFile = path;
+ }
+
+ Map<String, String> getEffectiveConfig() throws IOException {
+ if (effectiveConfig == null) {
+ if (propertiesFile == null) {
+ effectiveConfig = conf;
+ } else {
+ effectiveConfig = new HashMap<>(conf);
+ Properties p = loadPropertiesFile();
+ for (String key : p.stringPropertyNames()) {
+ if (!effectiveConfig.containsKey(key)) {
+ effectiveConfig.put(key, p.getProperty(key));
+ }
+ }
+ }
+ }
+ return effectiveConfig;
+ }
+
/**
* Loads the configuration file for the application, if it exists. This is either the
* user-specified properties file, or the spark-defaults.conf file under the Spark configuration
* directory.
*/
- Properties loadPropertiesFile() throws IOException {
+ private Properties loadPropertiesFile() throws IOException {
Properties props = new Properties();
File propsFile;
if (propertiesFile != null) {
@@ -294,10 +324,6 @@ abstract class AbstractCommandBuilder {
return props;
}
- String getenv(String key) {
- return firstNonEmpty(childEnv.get(key), System.getenv(key));
- }
-
private String findAssembly() {
String sparkHome = getSparkHome();
File libdir;
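To illustrate the merge order implemented by getEffectiveConfig() above, here is a small, hedged Scala sketch; the file path and memory values are made up. Values set programmatically take precedence, and the properties file only fills in keys that were not set explicitly.

    import org.apache.spark.launcher.SparkLauncher

    object ConfPrecedenceSketch {
      def main(args: Array[String]): Unit = {
        val launcher = new SparkLauncher()
          // Hypothetical file; suppose it contains spark.driver.memory=1g.
          .setPropertiesFile("/tmp/my-spark-defaults.conf")
          // Explicit values shadow the file, so the driver gets 4g, not 1g.
          .setConf(SparkLauncher.DRIVER_MEMORY, "4g")
        // The builder consults this merged ("effective") config when it assembles
        // the spark-submit command, e.g. for the driver's -Xms/-Xmx in client mode.
      }
    }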
diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java
new file mode 100644
index 0000000000..de50f14fbd
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ThreadFactory;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Handle implementation for monitoring apps started as a child process.
+ */
+class ChildProcAppHandle implements SparkAppHandle {
+
+ private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName());
+ private static final ThreadFactory REDIRECTOR_FACTORY =
+ new NamedThreadFactory("launcher-proc-%d");
+
+ private final String secret;
+ private final LauncherServer server;
+
+ private Process childProc;
+ private boolean disposed;
+ private LauncherConnection connection;
+ private List<Listener> listeners;
+ private State state;
+ private String appId;
+ private OutputRedirector redirector;
+
+ ChildProcAppHandle(String secret, LauncherServer server) {
+ this.secret = secret;
+ this.server = server;
+ this.state = State.UNKNOWN;
+ }
+
+ @Override
+ public synchronized void addListener(Listener l) {
+ if (listeners == null) {
+ listeners = new ArrayList<>();
+ }
+ listeners.add(l);
+ }
+
+ @Override
+ public State getState() {
+ return state;
+ }
+
+ @Override
+ public String getAppId() {
+ return appId;
+ }
+
+ @Override
+ public void stop() {
+ CommandBuilderUtils.checkState(connection != null, "Application is still not connected.");
+ try {
+ connection.send(new LauncherProtocol.Stop());
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ }
+
+ @Override
+ public synchronized void disconnect() {
+ if (!disposed) {
+ disposed = true;
+ if (connection != null) {
+ try {
+ connection.close();
+ } catch (IOException ioe) {
+ // no-op.
+ }
+ }
+ server.unregister(this);
+ if (redirector != null) {
+ redirector.stop();
+ }
+ }
+ }
+
+ @Override
+ public synchronized void kill() {
+ if (!disposed) {
+ disconnect();
+ }
+ if (childProc != null) {
+ childProc.destroy();
+ childProc = null;
+ }
+ }
+
+ String getSecret() {
+ return secret;
+ }
+
+ void setChildProc(Process childProc, String loggerName) {
+ this.childProc = childProc;
+ this.redirector = new OutputRedirector(childProc.getInputStream(), loggerName,
+ REDIRECTOR_FACTORY);
+ }
+
+ void setConnection(LauncherConnection connection) {
+ this.connection = connection;
+ }
+
+ LauncherServer getServer() {
+ return server;
+ }
+
+ LauncherConnection getConnection() {
+ return connection;
+ }
+
+ void setState(State s) {
+ if (!state.isFinal()) {
+ state = s;
+ fireEvent(false);
+ } else {
+ LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.",
+ new Object[] { state, s });
+ }
+ }
+
+ void setAppId(String appId) {
+ this.appId = appId;
+ fireEvent(true);
+ }
+
+ private synchronized void fireEvent(boolean isInfoChanged) {
+ if (listeners != null) {
+ for (Listener l : listeners) {
+ if (isInfoChanged) {
+ l.infoChanged(this);
+ } else {
+ l.stateChanged(this);
+ }
+ }
+ }
+ }
+
+}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
new file mode 100644
index 0000000000..eec264909b
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.Closeable;
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.net.Socket;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import static org.apache.spark.launcher.LauncherProtocol.*;
+
+/**
+ * Encapsulates a connection between a launcher server and client. This takes care of the
+ * communication (sending and receiving messages), while processing of messages is left for
+ * the implementations.
+ */
+abstract class LauncherConnection implements Closeable, Runnable {
+
+ private static final Logger LOG = Logger.getLogger(LauncherConnection.class.getName());
+
+ private final Socket socket;
+ private final ObjectOutputStream out;
+
+ private volatile boolean closed;
+
+ LauncherConnection(Socket socket) throws IOException {
+ this.socket = socket;
+ this.out = new ObjectOutputStream(socket.getOutputStream());
+ this.closed = false;
+ }
+
+ protected abstract void handle(Message msg) throws IOException;
+
+ @Override
+ public void run() {
+ try {
+ ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
+ while (!closed) {
+ Message msg = (Message) in.readObject();
+ handle(msg);
+ }
+ } catch (EOFException eof) {
+ // Remote side has closed the connection, just cleanup.
+ try {
+ close();
+ } catch (Exception unused) {
+ // no-op.
+ }
+ } catch (Exception e) {
+ if (!closed) {
+ LOG.log(Level.WARNING, "Error in inbound message handling.", e);
+ try {
+ close();
+ } catch (Exception unused) {
+ // no-op.
+ }
+ }
+ }
+ }
+
+ protected synchronized void send(Message msg) throws IOException {
+ try {
+ CommandBuilderUtils.checkState(!closed, "Disconnected.");
+ out.writeObject(msg);
+ out.flush();
+ } catch (IOException ioe) {
+ if (!closed) {
+ LOG.log(Level.WARNING, "Error when sending message.", ioe);
+ try {
+ close();
+ } catch (Exception unused) {
+ // no-op.
+ }
+ }
+ throw ioe;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (!closed) {
+ synchronized (this) {
+ if (!closed) {
+ closed = true;
+ socket.close();
+ }
+ }
+ }
+ }
+
+}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java
new file mode 100644
index 0000000000..50f136497e
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.net.Socket;
+import java.util.Map;
+
+/**
+ * Message definitions for the launcher communication protocol. These messages must remain
+ * backwards-compatible, so that the launcher can talk to older versions of Spark that support
+ * the protocol.
+ */
+final class LauncherProtocol {
+
+ /** Environment variable where the server port is stored. */
+ static final String ENV_LAUNCHER_PORT = "_SPARK_LAUNCHER_PORT";
+
+ /** Environment variable where the secret for connecting back to the server is stored. */
+ static final String ENV_LAUNCHER_SECRET = "_SPARK_LAUNCHER_SECRET";
+
+ static class Message implements Serializable {
+
+ }
+
+ /**
+ * Hello message, sent from client to server.
+ */
+ static class Hello extends Message {
+
+ final String secret;
+ final String sparkVersion;
+
+ Hello(String secret, String version) {
+ this.secret = secret;
+ this.sparkVersion = version;
+ }
+
+ }
+
+ /**
+ * SetAppId message, sent from client to server.
+ */
+ static class SetAppId extends Message {
+
+ final String appId;
+
+ SetAppId(String appId) {
+ this.appId = appId;
+ }
+
+ }
+
+ /**
+ * SetState message, sent from client to server.
+ */
+ static class SetState extends Message {
+
+ final SparkAppHandle.State state;
+
+ SetState(SparkAppHandle.State state) {
+ this.state = state;
+ }
+
+ }
+
+ /**
+ * Stop message, sent from server to client to stop the application.
+ */
+ static class Stop extends Message {
+
+ }
+
+}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java
new file mode 100644
index 0000000000..c5fd40816d
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.security.SecureRandom;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Timer;
+import java.util.TimerTask;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import static org.apache.spark.launcher.LauncherProtocol.*;
+
+/**
+ * A server that listens locally for connections from clients launched by the library. Each client
+ * has a secret that it needs to send to the server to identify itself and establish the session.
+ *
+ * I/O is currently blocking (one thread per client). Clients have a limited time to connect back
+ * to the server; otherwise, the server will ignore the connection.
+ *
+ * === Architecture Overview ===
+ *
+ * The launcher server is used when Spark apps are launched as separate processes from the calling
+ * app. It looks more or less like the following:
+ *
+ *      -----------------------                                -----------------------
+ *      |      User App       |          spark-submit          |      Spark App      |
+ *      |                     |  ------------------->          |                     |
+ *      |        -------------|                                |-------------        |
+ *      |        |            |        hello                   |            |        |
+ *      |        | L. Server  |<-------------------------------| L. Backend |        |
+ *      |        |            |                                |            |        |
+ *      |        -------------|                                -----------------------
+ *      |              |                                                 ^
+ *      |              v                                                 |
+ *      |        -------------|                                          |
+ *      |        |            |          <per-app channel>               |
+ *      |        | App Handle |<------------------------------------------
+ *      |        |            |
+ *      -----------------------
+ *
+ * The server is started on demand and remains active while there are active or outstanding clients,
+ * to avoid opening too many ports when multiple clients are launched. Each client is given a unique
+ * secret, and has a limited amount of time to connect back
+ * ({@link SparkLauncher#CHILD_CONNECTION_TIMEOUT}), after which the server will throw away
+ * that client's state. A client is only allowed to connect back to the server once.
+ *
+ * The launcher server listens on localhost only, so it doesn't need access controls (aside from
+ * the per-app secret) or encryption. It thus requires that the launched app has a local process
+ * that communicates with the server. In cluster mode, this means that the client that launches the
+ * application must remain alive for the duration of the application (or until the app handle is
+ * disconnected).
+ */
+class LauncherServer implements Closeable {
+
+ private static final Logger LOG = Logger.getLogger(LauncherServer.class.getName());
+ private static final String THREAD_NAME_FMT = "LauncherServer-%d";
+ private static final long DEFAULT_CONNECT_TIMEOUT = 10000L;
+
+ /** For creating secrets used for communication with child processes. */
+ private static final SecureRandom RND = new SecureRandom();
+
+ private static volatile LauncherServer serverInstance;
+
+ /**
+ * Creates a handle for an app to be launched. This method will start a server if one hasn't been
+ * started yet. The server is shared for multiple handles, and once all handles are disposed of,
+ * the server is shut down.
+ */
+ static synchronized ChildProcAppHandle newAppHandle() throws IOException {
+ LauncherServer server = serverInstance != null ? serverInstance : new LauncherServer();
+ server.ref();
+ serverInstance = server;
+
+ String secret = server.createSecret();
+ while (server.pending.containsKey(secret)) {
+ secret = server.createSecret();
+ }
+
+ return server.newAppHandle(secret);
+ }
+
+ static LauncherServer getServerInstance() {
+ return serverInstance;
+ }
+
+ private final AtomicLong refCount;
+ private final AtomicLong threadIds;
+ private final ConcurrentMap<String, ChildProcAppHandle> pending;
+ private final List<ServerConnection> clients;
+ private final ServerSocket server;
+ private final Thread serverThread;
+ private final ThreadFactory factory;
+ private final Timer timeoutTimer;
+
+ private volatile boolean running;
+
+ private LauncherServer() throws IOException {
+ this.refCount = new AtomicLong(0);
+
+ ServerSocket server = new ServerSocket();
+ try {
+ server.setReuseAddress(true);
+ server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
+
+ this.clients = new ArrayList<ServerConnection>();
+ this.threadIds = new AtomicLong();
+ this.factory = new NamedThreadFactory(THREAD_NAME_FMT);
+ this.pending = new ConcurrentHashMap<>();
+ this.timeoutTimer = new Timer("LauncherServer-TimeoutTimer", true);
+ this.server = server;
+ this.running = true;
+
+ this.serverThread = factory.newThread(new Runnable() {
+ @Override
+ public void run() {
+ acceptConnections();
+ }
+ });
+ serverThread.start();
+ } catch (IOException ioe) {
+ close();
+ throw ioe;
+ } catch (Exception e) {
+ close();
+ throw new IOException(e);
+ }
+ }
+
+ /**
+ * Creates a new app handle. The handle will wait for an incoming connection for a configurable
+ * amount of time, and if one doesn't arrive, it will transition to an error state.
+ */
+ ChildProcAppHandle newAppHandle(String secret) {
+ ChildProcAppHandle handle = new ChildProcAppHandle(secret, this);
+ ChildProcAppHandle existing = pending.putIfAbsent(secret, handle);
+ CommandBuilderUtils.checkState(existing == null, "Multiple handles with the same secret.");
+ return handle;
+ }
+
+ @Override
+ public void close() throws IOException {
+ synchronized (this) {
+ if (running) {
+ running = false;
+ timeoutTimer.cancel();
+ server.close();
+ synchronized (clients) {
+ List<ServerConnection> copy = new ArrayList<>(clients);
+ clients.clear();
+ for (ServerConnection client : copy) {
+ client.close();
+ }
+ }
+ }
+ }
+ if (serverThread != null) {
+ try {
+ serverThread.join();
+ } catch (InterruptedException ie) {
+ // no-op
+ }
+ }
+ }
+
+ void ref() {
+ refCount.incrementAndGet();
+ }
+
+ void unref() {
+ synchronized(LauncherServer.class) {
+ if (refCount.decrementAndGet() == 0) {
+ try {
+ close();
+ } catch (IOException ioe) {
+ // no-op.
+ } finally {
+ serverInstance = null;
+ }
+ }
+ }
+ }
+
+ int getPort() {
+ return server.getLocalPort();
+ }
+
+ /**
+ * Removes the client handle from the pending list (in case it's still there), and unrefs
+ * the server.
+ */
+ void unregister(ChildProcAppHandle handle) {
+ pending.remove(handle.getSecret());
+ unref();
+ }
+
+ private void acceptConnections() {
+ try {
+ while (running) {
+ final Socket client = server.accept();
+ TimerTask timeout = new TimerTask() {
+ @Override
+ public void run() {
+ LOG.warning("Timed out waiting for hello message from client.");
+ try {
+ client.close();
+ } catch (IOException ioe) {
+ // no-op.
+ }
+ }
+ };
+ ServerConnection clientConnection = new ServerConnection(client, timeout);
+ Thread clientThread = factory.newThread(clientConnection);
+ synchronized (timeout) {
+ clientThread.start();
+ synchronized (clients) {
+ clients.add(clientConnection);
+ }
+ timeoutTimer.schedule(timeout, getConnectionTimeout());
+ }
+ }
+ } catch (IOException ioe) {
+ if (running) {
+ LOG.log(Level.SEVERE, "Error in accept loop.", ioe);
+ }
+ }
+ }
+
+ private long getConnectionTimeout() {
+ String value = SparkLauncher.launcherConfig.get(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
+ return (value != null) ? Long.parseLong(value) : DEFAULT_CONNECT_TIMEOUT;
+ }
+
+ private String createSecret() {
+ byte[] secret = new byte[128];
+ RND.nextBytes(secret);
+
+ StringBuilder sb = new StringBuilder();
+ for (byte b : secret) {
+ int ival = b >= 0 ? b : Byte.MAX_VALUE - b;
+ if (ival < 0x10) {
+ sb.append("0");
+ }
+ sb.append(Integer.toHexString(ival));
+ }
+ return sb.toString();
+ }
+
+ private class ServerConnection extends LauncherConnection {
+
+ private TimerTask timeout;
+ private ChildProcAppHandle handle;
+
+ ServerConnection(Socket socket, TimerTask timeout) throws IOException {
+ super(socket);
+ this.timeout = timeout;
+ }
+
+ @Override
+ protected void handle(Message msg) throws IOException {
+ try {
+ if (msg instanceof Hello) {
+ synchronized (timeout) {
+ timeout.cancel();
+ }
+ timeout = null;
+ Hello hello = (Hello) msg;
+ ChildProcAppHandle handle = pending.remove(hello.secret);
+ if (handle != null) {
+ handle.setState(SparkAppHandle.State.CONNECTED);
+ handle.setConnection(this);
+ this.handle = handle;
+ } else {
+ throw new IllegalArgumentException("Received Hello for unknown client.");
+ }
+ } else {
+ if (handle == null) {
+ throw new IllegalArgumentException("Expected hello, got: " +
+ msg != null ? msg.getClass().getName() : null);
+ }
+ if (msg instanceof SetAppId) {
+ SetAppId set = (SetAppId) msg;
+ handle.setAppId(set.appId);
+ } else if (msg instanceof SetState) {
+ handle.setState(((SetState)msg).state);
+ } else {
+ throw new IllegalArgumentException("Invalid message: " +
+ msg != null ? msg.getClass().getName() : null);
+ }
+ }
+ } catch (Exception e) {
+ LOG.log(Level.INFO, "Error handling message from client.", e);
+ if (timeout != null) {
+ timeout.cancel();
+ }
+ close();
+ } finally {
+ timeoutTimer.purge();
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ synchronized (clients) {
+ clients.remove(this);
+ }
+ super.close();
+ if (handle != null) {
+ handle.disconnect();
+ }
+ }
+
+ }
+
+}
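The connection timeout read by getConnectionTimeout() comes from the static launcher configuration (SparkLauncher.setConfig) rather than from the application's Spark conf. A minimal Scala sketch of raising it; the 30-second value is an arbitrary example:

    import org.apache.spark.launcher.SparkLauncher

    object LauncherTimeoutSketch {
      def main(args: Array[String]): Unit = {
        // Launcher-library setting, applied to all apps started from this JVM:
        // give child processes 30s (instead of the default 10s) to connect back
        // before the server discards their state.
        SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, "30000")
      }
    }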
diff --git a/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java b/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java
new file mode 100644
index 0000000000..995f4d73da
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.atomic.AtomicLong;
+
+class NamedThreadFactory implements ThreadFactory {
+
+ private final String nameFormat;
+ private final AtomicLong threadIds;
+
+ NamedThreadFactory(String nameFormat) {
+ this.nameFormat = nameFormat;
+ this.threadIds = new AtomicLong();
+ }
+
+ @Override
+ public Thread newThread(Runnable r) {
+ Thread t = new Thread(r, String.format(nameFormat, threadIds.incrementAndGet()));
+ t.setDaemon(true);
+ return t;
+ }
+
+}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java
new file mode 100644
index 0000000000..6e7120167d
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.BufferedReader;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.IOException;
+import java.util.concurrent.ThreadFactory;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Redirects lines read from a given input stream to a j.u.l.Logger (at INFO level).
+ */
+class OutputRedirector {
+
+ private final BufferedReader reader;
+ private final Logger sink;
+ private final Thread thread;
+
+ private volatile boolean active;
+
+ OutputRedirector(InputStream in, ThreadFactory tf) {
+ this(in, OutputRedirector.class.getName(), tf);
+ }
+
+ OutputRedirector(InputStream in, String loggerName, ThreadFactory tf) {
+ this.active = true;
+ this.reader = new BufferedReader(new InputStreamReader(in));
+ this.thread = tf.newThread(new Runnable() {
+ @Override
+ public void run() {
+ redirect();
+ }
+ });
+ this.sink = Logger.getLogger(loggerName);
+ thread.start();
+ }
+
+ private void redirect() {
+ try {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ if (active) {
+ sink.info(line.replaceFirst("\\s*$", ""));
+ }
+ }
+ } catch (IOException e) {
+ sink.log(Level.FINE, "Error reading child process output.", e);
+ }
+ }
+
+ /**
+ * This method just stops the output of the process from showing up in the local logs.
+ * The child's output will still be read (and, thus, the redirect thread will still be
+ * alive) to avoid the child process hanging because its output buffer fills up.
+ */
+ void stop() {
+ active = false;
+ }
+
+}
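Because OutputRedirector writes the child's merged stdout/stderr to a java.util.logging logger, callers can choose the logger name and route it wherever they like; the jul-to-slf4j dependency added to the launcher pom and the SLF4JBridgeHandler call in SparkLauncherSuite show one such route. A hedged Scala sketch follows; the "myapp" name, jar path, and main class are hypothetical:

    import java.util.logging.{Level, Logger}

    import org.apache.spark.launcher.SparkLauncher

    object ChildOutputSketch {
      def main(args: Array[String]): Unit = {
        val handle = new SparkLauncher()
          // Child output will appear on the logger "org.apache.spark.launcher.app.myapp".
          .setConf(SparkLauncher.CHILD_PROCESS_LOGGER_NAME, "myapp")
          .setAppResource("/path/to/app.jar")   // hypothetical jar
          .setMainClass("com.example.MyApp")    // hypothetical main class
          .startApplication()

        // Tune or redirect that logger as needed; installing the jul-to-slf4j
        // bridge (org.slf4j.bridge.SLF4JBridgeHandler.install()) would forward
        // these records to SLF4J instead.
        Logger.getLogger("org.apache.spark.launcher.app.myapp").setLevel(Level.INFO)
        // handle can then be used to monitor the application (see SparkAppHandle).
      }
    }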
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java
new file mode 100644
index 0000000000..2896a91d5e
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+/**
+ * A handle to a running Spark application.
+ * <p/>
+ * Provides runtime information about the underlying Spark application, and actions to control it.
+ *
+ * @since 1.6.0
+ */
+public interface SparkAppHandle {
+
+ /**
+ * Represents the application's state. A state can be "final", in which case it will not change
+ * after it's reached, meaning the application is no longer running.
+ *
+ * @since 1.6.0
+ */
+ public enum State {
+ /** The application has not reported back yet. */
+ UNKNOWN(false),
+ /** The application has connected to the handle. */
+ CONNECTED(false),
+ /** The application has been submitted to the cluster. */
+ SUBMITTED(false),
+ /** The application is running. */
+ RUNNING(false),
+ /** The application finished with a successful status. */
+ FINISHED(true),
+ /** The application finished with a failed status. */
+ FAILED(true),
+ /** The application was killed. */
+ KILLED(true);
+
+ private final boolean isFinal;
+
+ State(boolean isFinal) {
+ this.isFinal = isFinal;
+ }
+
+ /**
+ * Whether this state is a final state, meaning the application is not running anymore
+ * once it's reached.
+ */
+ public boolean isFinal() {
+ return isFinal;
+ }
+ }
+
+ /**
+ * Adds a listener to be notified of changes to the handle's information. Listeners will be called
+ * from the thread processing updates from the application, so they should avoid blocking or
+ * long-running operations.
+ *
+ * @param l Listener to add.
+ */
+ void addListener(Listener l);
+
+ /** Returns the current application state. */
+ State getState();
+
+ /** Returns the application ID, or <code>null</code> if not yet known. */
+ String getAppId();
+
+ /**
+ * Asks the application to stop. This is best-effort, since the application may fail to receive
+ * or act on the command. Callers should watch for a state transition that indicates the
+ * application has really stopped.
+ */
+ void stop();
+
+ /**
+ * Tries to kill the underlying application. Implies {@link #disconnect()}. This will not send
+ * a {@link #stop()} message to the application, so it's recommended that users first try to
+ * stop the application cleanly and only resort to this method if that fails.
+ */
+ void kill();
+
+ /**
+ * Disconnects the handle from the application, without stopping it. After this method is called,
+ * the handle will not be able to communicate with the application anymore.
+ */
+ void disconnect();
+
+ /**
+ * Listener for updates to a handle's state. The callbacks do not receive information about
+ * what exactly has changed, just that an update has occurred.
+ *
+ * @since 1.6.0
+ */
+ public interface Listener {
+
+ /**
+ * Callback for changes in the handle's state.
+ *
+ * @param handle The updated handle.
+ * @see {@link SparkAppHandle#getState()}
+ */
+ void stateChanged(SparkAppHandle handle);
+
+ /**
+ * Callback for changes in any information that is not the handle's state.
+ *
+ * @param handle The updated handle.
+ */
+ void infoChanged(SparkAppHandle handle);
+
+ }
+
+}
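A hedged Scala sketch of implementing the Listener interface above; the println statements, jar path, and main class are placeholders. Callbacks arrive on the launcher's connection thread, so they should return quickly:

    import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

    object ListenerSketch {

      class LoggingListener extends SparkAppHandle.Listener {
        override def stateChanged(handle: SparkAppHandle): Unit = {
          val state = handle.getState()
          println(s"state -> $state (final: ${state.isFinal()})")
        }
        override def infoChanged(handle: SparkAppHandle): Unit = {
          println(s"app id -> ${handle.getAppId()}")
        }
      }

      def main(args: Array[String]): Unit = {
        val handle = new SparkLauncher()
          .setAppResource("/path/to/app.jar")   // hypothetical
          .setMainClass("com.example.MyApp")    // hypothetical
          .startApplication(new LoggingListener)

        // The launcher process must stay alive to keep receiving callbacks.
        while (!handle.getState().isFinal()) {
          Thread.sleep(1000)
        }
      }
    }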
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
index 57993405e4..5d74b37033 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
@@ -21,8 +21,10 @@ import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
import static org.apache.spark.launcher.CommandBuilderUtils.*;
@@ -58,6 +60,33 @@ public class SparkLauncher {
/** Configuration key for the number of executor CPU cores. */
public static final String EXECUTOR_CORES = "spark.executor.cores";
+ /** Logger name to use when launching a child process. */
+ public static final String CHILD_PROCESS_LOGGER_NAME = "spark.launcher.childProcLoggerName";
+
+ /**
+ * Maximum time (in ms) to wait for a child process to connect back to the launcher server
+ * when using @link{#start()}.
+ */
+ public static final String CHILD_CONNECTION_TIMEOUT = "spark.launcher.childConectionTimeout";
+
+ /** Used internally to create unique logger names. */
+ private static final AtomicInteger COUNTER = new AtomicInteger();
+
+ static final Map<String, String> launcherConfig = new HashMap<String, String>();
+
+ /**
+ * Set a configuration value for the launcher library. These config values do not affect the
+ * launched application, but rather the behavior of the launcher library itself when managing
+ * applications.
+ *
+ * @since 1.6.0
+ * @param name Config name.
+ * @param value Config value.
+ */
+ public static void setConfig(String name, String value) {
+ launcherConfig.put(name, value);
+ }
+
// Visible for testing.
final SparkSubmitCommandBuilder builder;
@@ -109,7 +138,7 @@ public class SparkLauncher {
*/
public SparkLauncher setPropertiesFile(String path) {
checkNotNull(path, "path");
- builder.propertiesFile = path;
+ builder.setPropertiesFile(path);
return this;
}
@@ -197,6 +226,7 @@ public class SparkLauncher {
* Use this method with caution. It is possible to create an invalid Spark command by passing
* unknown arguments to this method, since those are allowed for forward compatibility.
*
+ * @since 1.5.0
* @param arg Argument to add.
* @return This launcher.
*/
@@ -218,6 +248,7 @@ public class SparkLauncher {
* Use this method with caution. It is possible to create an invalid Spark command by passing
* unknown arguments to this method, since those are allowed for forward compatibility.
*
+ * @since 1.5.0
* @param name Name of argument to add.
* @param value Value of the argument.
* @return This launcher.
@@ -319,10 +350,81 @@ public class SparkLauncher {
/**
* Launches a sub-process that will start the configured Spark application.
+ * <p/>
+ * The {@link #startApplication(SparkAppHandle.Listener...)} method is preferred when launching
+ * Spark, since it provides better control of the child application.
*
* @return A process handle for the Spark app.
*/
public Process launch() throws IOException {
+ return createBuilder().start();
+ }
+
+ /**
+ * Starts a Spark application.
+ * <p/>
+ * This method returns a handle that provides information about the running application and can
+ * be used to do basic interaction with it.
+ * <p/>
+ * The returned handle assumes that the application will instantiate a single SparkContext
+ * during its lifetime. Once that context reports a final state (one that indicates the
+ * SparkContext has stopped), the handle will not perform new state transitions, so anything
+ * that happens after that cannot be monitored. If the underlying application is launched as
+ * a child process, {@link SparkAppHandle#kill()} can still be used to kill the child process.
+ * <p/>
+ * Currently, all applications are launched as child processes. The child's stdout and stderr
+ * are merged and written to a logger (see <code>java.util.logging</code>). The logger's name
+ * can be defined by setting {@link #CHILD_PROCESS_LOGGER_NAME} in the app's configuration. If
+ * that option is not set, the code will try to derive a name from the application's name or
+ * main class / script file. If those cannot be determined, an internal, unique name will be
+ * used. In all cases, the logger name will start with "org.apache.spark.launcher.app", to fit
+ * more easily into the configuration of commonly-used logging systems.
+ *
+ * @since 1.6.0
+ * @param listeners Listeners to add to the handle before the app is launched.
+ * @return A handle for the launched application.
+ */
+ public SparkAppHandle startApplication(SparkAppHandle.Listener... listeners) throws IOException {
+ ChildProcAppHandle handle = LauncherServer.newAppHandle();
+ for (SparkAppHandle.Listener l : listeners) {
+ handle.addListener(l);
+ }
+
+ String appName = builder.getEffectiveConfig().get(CHILD_PROCESS_LOGGER_NAME);
+ if (appName == null) {
+ if (builder.appName != null) {
+ appName = builder.appName;
+ } else if (builder.mainClass != null) {
+ int dot = builder.mainClass.lastIndexOf(".");
+ if (dot >= 0 && dot < builder.mainClass.length() - 1) {
+ appName = builder.mainClass.substring(dot + 1, builder.mainClass.length());
+ } else {
+ appName = builder.mainClass;
+ }
+ } else if (builder.appResource != null) {
+ appName = new File(builder.appResource).getName();
+ } else {
+ appName = String.valueOf(COUNTER.incrementAndGet());
+ }
+ }
+
+ String loggerPrefix = getClass().getPackage().getName();
+ String loggerName = String.format("%s.app.%s", loggerPrefix, appName);
+ ProcessBuilder pb = createBuilder().redirectErrorStream(true);
+ pb.environment().put(LauncherProtocol.ENV_LAUNCHER_PORT,
+ String.valueOf(LauncherServer.getServerInstance().getPort()));
+ pb.environment().put(LauncherProtocol.ENV_LAUNCHER_SECRET, handle.getSecret());
+ try {
+ handle.setChildProc(pb.start(), loggerName);
+ } catch (IOException ioe) {
+ handle.kill();
+ throw ioe;
+ }
+
+ return handle;
+ }
+
+ private ProcessBuilder createBuilder() {
List<String> cmd = new ArrayList<String>();
String script = isWindows() ? "spark-submit.cmd" : "spark-submit";
cmd.add(join(File.separator, builder.getSparkHome(), "bin", script));
@@ -343,7 +445,7 @@ public class SparkLauncher {
for (Map.Entry<String, String> e : builder.childEnv.entrySet()) {
pb.environment().put(e.getKey(), e.getValue());
}
- return pb.start();
+ return pb;
}
private static class ArgumentValidator extends SparkSubmitOptionParser {
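To round out the SparkLauncher changes, a hedged Scala sketch of the stop-then-kill pattern suggested by the javadoc above; the jar path, main class, and 30-second grace period are assumptions. stop() is a best-effort request sent over the launcher connection and requires the app to have connected back, while kill() destroys the child process and implies disconnect().

    import java.util.concurrent.TimeUnit

    import org.apache.spark.launcher.SparkLauncher

    object StopThenKillSketch {
      def main(args: Array[String]): Unit = {
        val handle = new SparkLauncher()
          .setAppResource("/path/to/app.jar")        // hypothetical
          .setMainClass("com.example.LongRunning")   // hypothetical
          .startApplication()

        // Wait until the child has connected back and reported an app id;
        // stop() requires the application to be connected first.
        while (handle.getAppId() == null && !handle.getState().isFinal()) {
          Thread.sleep(100)
        }

        handle.stop()  // polite request; the app may take a while, or ignore it

        val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(30)
        while (!handle.getState().isFinal() && System.nanoTime() < deadline) {
          Thread.sleep(500)
        }
        if (!handle.getState().isFinal()) {
          handle.kill()  // last resort: destroy the child process
        }
      }
    }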
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index fc87814a59..39b46e0db8 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -188,10 +188,9 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
// Load the properties file and check whether spark-submit will be running the app's driver
// or just launching a cluster app. When running the driver, the JVM's argument will be
// modified to cover the driver's configuration.
- Properties props = loadPropertiesFile();
- boolean isClientMode = isClientMode(props);
- String extraClassPath = isClientMode ?
- firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_CLASSPATH, conf, props) : null;
+ Map<String, String> config = getEffectiveConfig();
+ boolean isClientMode = isClientMode(config);
+ String extraClassPath = isClientMode ? config.get(SparkLauncher.DRIVER_EXTRA_CLASSPATH) : null;
List<String> cmd = buildJavaCommand(extraClassPath);
// Take Thrift Server as daemon
@@ -212,14 +211,13 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
// Take Thrift Server as daemon
String tsMemory =
isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null;
- String memory = firstNonEmpty(tsMemory,
- firstNonEmptyValue(SparkLauncher.DRIVER_MEMORY, conf, props),
+ String memory = firstNonEmpty(tsMemory, config.get(SparkLauncher.DRIVER_MEMORY),
System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM);
cmd.add("-Xms" + memory);
cmd.add("-Xmx" + memory);
- addOptionString(cmd, firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, conf, props));
+ addOptionString(cmd, config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS));
mergeEnvPathList(env, getLibPathEnvName(),
- firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props));
+ config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH));
}
addPermGenSizeOpt(cmd);
@@ -281,9 +279,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
private void constructEnvVarArgs(
Map<String, String> env,
String submitArgsEnvVariable) throws IOException {
- Properties props = loadPropertiesFile();
mergeEnvPathList(env, getLibPathEnvName(),
- firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props));
+ getEffectiveConfig().get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH));
StringBuilder submitArgs = new StringBuilder();
for (String arg : buildSparkSubmitArgs()) {
@@ -295,9 +292,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
env.put(submitArgsEnvVariable, submitArgs.toString());
}
-
- private boolean isClientMode(Properties userProps) {
- String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER));
+ private boolean isClientMode(Map<String, String> userProps) {
+ String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER));
// Default master is "local[*]", so assume client mode in that case.
return userMaster == null ||
"client".equals(deployMode) ||
diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java
index 7c97dba511..d1ac39bdc7 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java
@@ -17,17 +17,42 @@
/**
* Library for launching Spark applications.
- *
+ *
* <p>
* This library allows applications to launch Spark programmatically. There's only one entry
* point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class.
* </p>
*
* <p>
- * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher}
- * and configure the application to run. For example:
+ * The {@link org.apache.spark.launcher.SparkLauncher#startApplication(
+ * org.apache.spark.launcher.SparkAppHandle.Listener...)} can be used to start Spark and provide
+ * a handle to monitor and control the running application:
* </p>
- *
+ *
+ * <pre>
+ * {@code
+ * import org.apache.spark.launcher.SparkAppHandle;
+ * import org.apache.spark.launcher.SparkLauncher;
+ *
+ * public class MyLauncher {
+ * public static void main(String[] args) throws Exception {
+ * SparkAppHandle handle = new SparkLauncher()
+ * .setAppResource("/my/app.jar")
+ * .setMainClass("my.spark.app.Main")
+ * .setMaster("local")
+ * .setConf(SparkLauncher.DRIVER_MEMORY, "2g")
+ * .startApplication();
+ * // Use handle API to monitor / control application.
+ * }
+ * }
+ * }
+ * </pre>
+ *
+ * <p>
+ * It's also possible to launch a raw child process, using the
+ * {@link org.apache.spark.launcher.SparkLauncher#launch()} method:
+ * </p>
+ *
* <pre>
* {@code
* import org.apache.spark.launcher.SparkLauncher;
@@ -45,5 +70,10 @@
* }
* }
* </pre>
+ *
+ * <p>This method requires the calling code to manually manage the child process, including its
+ * output streams (to avoid possible deadlocks). It's recommended that
+ * {@link org.apache.spark.launcher.SparkLauncher#startApplication(
+ * org.apache.spark.launcher.SparkAppHandle.Listener...)} be used instead.</p>
*/
package org.apache.spark.launcher;
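
As the note above says, launch() leaves child-process management to the caller. A hedged sketch of the minimum stream handling involved (paths and class names are placeholders; a real caller should drain stderr the same way, or merge the two streams):

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import org.apache.spark.launcher.SparkLauncher;

    public class RawLaunchExample {
      public static void main(String[] args) throws Exception {
        final Process spark = new SparkLauncher()
          .setAppResource("/path/to/app.jar")   // placeholder
          .setMainClass("com.example.Main")     // placeholder
          .setMaster("local")
          .launch();

        // Drain stdout so the child does not block on a full pipe buffer.
        Thread drainer = new Thread(new Runnable() {
          @Override
          public void run() {
            BufferedReader in =
              new BufferedReader(new InputStreamReader(spark.getInputStream()));
            try {
              String line;
              while ((line = in.readLine()) != null) {
                System.out.println(line);
              }
            } catch (IOException ioe) {
              // The child exited; nothing more to read.
            }
          }
        });
        drainer.setDaemon(true);
        drainer.start();

        int exitCode = spark.waitFor();
        System.out.println("spark-submit exited with code " + exitCode);
      }
    }
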
diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java
new file mode 100644
index 0000000000..23e2c64d6d
--- /dev/null
+++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import org.slf4j.bridge.SLF4JBridgeHandler;
+
+/**
+ * Handles configuring the JUL -> SLF4J bridge.
+ */
+class BaseSuite {
+
+ static {
+ SLF4JBridgeHandler.removeHandlersForRootLogger();
+ SLF4JBridgeHandler.install();
+ }
+
+}
diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
new file mode 100644
index 0000000000..27cd1061a1
--- /dev/null
+++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.Socket;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import static org.apache.spark.launcher.LauncherProtocol.*;
+
+public class LauncherServerSuite extends BaseSuite {
+
+ @Test
+ public void testLauncherServerReuse() throws Exception {
+ ChildProcAppHandle handle1 = null;
+ ChildProcAppHandle handle2 = null;
+ ChildProcAppHandle handle3 = null;
+
+ try {
+ handle1 = LauncherServer.newAppHandle();
+ handle2 = LauncherServer.newAppHandle();
+ LauncherServer server1 = handle1.getServer();
+ assertSame(server1, handle2.getServer());
+
+ handle1.kill();
+ handle2.kill();
+
+ handle3 = LauncherServer.newAppHandle();
+ assertNotSame(server1, handle3.getServer());
+
+ handle3.kill();
+
+ assertNull(LauncherServer.getServerInstance());
+ } finally {
+ kill(handle1);
+ kill(handle2);
+ kill(handle3);
+ }
+ }
+
+ @Test
+ public void testCommunication() throws Exception {
+ ChildProcAppHandle handle = LauncherServer.newAppHandle();
+ TestClient client = null;
+ try {
+ Socket s = new Socket(InetAddress.getLoopbackAddress(),
+ LauncherServer.getServerInstance().getPort());
+
+ final Object waitLock = new Object();
+ handle.addListener(new SparkAppHandle.Listener() {
+ @Override
+ public void stateChanged(SparkAppHandle handle) {
+ wakeUp();
+ }
+
+ @Override
+ public void infoChanged(SparkAppHandle handle) {
+ wakeUp();
+ }
+
+ private void wakeUp() {
+ synchronized (waitLock) {
+ waitLock.notifyAll();
+ }
+ }
+ });
+
+ client = new TestClient(s);
+ synchronized (waitLock) {
+ client.send(new Hello(handle.getSecret(), "1.4.0"));
+ waitLock.wait(TimeUnit.SECONDS.toMillis(10));
+ }
+
+ // Make sure the server matched the client to the handle.
+ assertNotNull(handle.getConnection());
+
+ synchronized (waitLock) {
+ client.send(new SetAppId("app-id"));
+ waitLock.wait(TimeUnit.SECONDS.toMillis(10));
+ }
+ assertEquals("app-id", handle.getAppId());
+
+ synchronized (waitLock) {
+ client.send(new SetState(SparkAppHandle.State.RUNNING));
+ waitLock.wait(TimeUnit.SECONDS.toMillis(10));
+ }
+ assertEquals(SparkAppHandle.State.RUNNING, handle.getState());
+
+ handle.stop();
+ Message stopMsg = client.inbound.poll(10, TimeUnit.SECONDS);
+ assertTrue(stopMsg instanceof Stop);
+ } finally {
+ kill(handle);
+ close(client);
+ if (client != null) {
+ client.clientThread.join();
+ }
+ }
+ }
+
+ @Test
+ public void testTimeout() throws Exception {
+ final long TEST_TIMEOUT = 10L;
+
+ ChildProcAppHandle handle = null;
+ TestClient client = null;
+ try {
+ SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, String.valueOf(TEST_TIMEOUT));
+
+ handle = LauncherServer.newAppHandle();
+
+ Socket s = new Socket(InetAddress.getLoopbackAddress(),
+ LauncherServer.getServerInstance().getPort());
+ client = new TestClient(s);
+
+ // Sleep for longer than the configured connection timeout so the server drops the
+ // unauthenticated connection; the Hello message below should then fail.
+ Thread.sleep(TEST_TIMEOUT * 10);
+ try {
+ client.send(new Hello(handle.getSecret(), "1.4.0"));
+ fail("Expected exception caused by connection timeout.");
+ } catch (IllegalStateException e) {
+ // Expected.
+ }
+ } finally {
+ SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
+ kill(handle);
+ close(client);
+ }
+ }
+
+ private void kill(SparkAppHandle handle) {
+ if (handle != null) {
+ handle.kill();
+ }
+ }
+
+ private void close(Closeable c) {
+ if (c != null) {
+ try {
+ c.close();
+ } catch (Exception e) {
+ // no-op.
+ }
+ }
+ }
+
+ private static class TestClient extends LauncherConnection {
+
+ final BlockingQueue<Message> inbound;
+ final Thread clientThread;
+
+ TestClient(Socket s) throws IOException {
+ super(s);
+ this.inbound = new LinkedBlockingQueue<Message>();
+ this.clientThread = new Thread(this);
+ clientThread.setName("TestClient");
+ clientThread.setDaemon(true);
+ clientThread.start();
+ }
+
+ @Override
+ protected void handle(Message msg) throws IOException {
+ inbound.offer(msg);
+ }
+
+ }
+
+}
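
The wait/notify pattern in testCommunication() can be expressed more compactly by callers of the public API. A hedged helper sketch (not part of the launcher library) that blocks until a handle reports a given state:

    import java.util.concurrent.CountDownLatch;
    import org.apache.spark.launcher.SparkAppHandle;

    final class HandleStates {
      // Hypothetical helper: waits until the handle reaches the expected state.
      static void awaitState(SparkAppHandle handle, final SparkAppHandle.State expected)
          throws InterruptedException {
        final CountDownLatch latch = new CountDownLatch(1);
        handle.addListener(new SparkAppHandle.Listener() {
          @Override
          public void stateChanged(SparkAppHandle h) {
            if (h.getState() == expected) {
              latch.countDown();
            }
          }
          @Override
          public void infoChanged(SparkAppHandle h) { }
        });
        if (handle.getState() == expected) {
          return;  // already there; avoid waiting for a notification that will not come
        }
        latch.await();
      }
    }
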
diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
index 7329ac9f7f..d5397b0685 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
@@ -30,7 +30,7 @@ import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
-public class SparkSubmitCommandBuilderSuite {
+public class SparkSubmitCommandBuilderSuite extends BaseSuite {
private static File dummyPropsFile;
private static SparkSubmitOptionParser parser;
@@ -161,7 +161,7 @@ public class SparkSubmitCommandBuilderSuite {
launcher.appResource = "/foo";
launcher.appName = "MyApp";
launcher.mainClass = "my.Class";
- launcher.propertiesFile = dummyPropsFile.getAbsolutePath();
+ launcher.setPropertiesFile(dummyPropsFile.getAbsolutePath());
launcher.appArgs.add("foo");
launcher.appArgs.add("bar");
launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g");
diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java
index f3d2109917..3ee5b8cf96 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java
@@ -28,7 +28,7 @@ import static org.mockito.Mockito.*;
import static org.apache.spark.launcher.SparkSubmitOptionParser.*;
-public class SparkSubmitOptionParserSuite {
+public class SparkSubmitOptionParserSuite extends BaseSuite {
private SparkSubmitOptionParser parser;
diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties
index 67a6a98217..c64b1565e1 100644
--- a/launcher/src/test/resources/log4j.properties
+++ b/launcher/src/test/resources/log4j.properties
@@ -16,16 +16,19 @@
#
# Set everything to be logged to the file core/target/unit-tests.log
-log4j.rootCategory=INFO, file
+test.appender=file
+log4j.rootCategory=INFO, ${test.appender}
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-
-# Some tests will set "test.name" to avoid overwriting the main log file.
-log4j.appender.file.file=target/unit-tests${test.name}.log
-
+log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+log4j.appender.childproc=org.apache.log4j.ConsoleAppender
+log4j.appender.childproc.target=System.err
+log4j.appender.childproc.layout=org.apache.log4j.PatternLayout
+log4j.appender.childproc.layout.ConversionPattern=%t: %m%n
+
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.spark-project.jetty=WARN
org.spark-project.jetty.LEVEL=WARN
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index eb3b7fb885..cec81b9406 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -55,8 +55,8 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException
import org.apache.hadoop.yarn.util.Records
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException}
+import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils}
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.launcher.YarnCommandBuilderUtils
import org.apache.spark.util.Utils
private[spark] class Client(
@@ -70,8 +70,6 @@ private[spark] class Client(
def this(clientArgs: ClientArguments, spConf: SparkConf) =
this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf)
- def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf())
-
private val yarnClient = YarnClient.createYarnClient
private val yarnConf = new YarnConfiguration(hadoopConf)
private var credentials: Credentials = null
@@ -84,10 +82,27 @@ private[spark] class Client(
private var principal: String = null
private var keytab: String = null
+ private val launcherBackend = new LauncherBackend() {
+ override def onStopRequest(): Unit = {
+ if (isClusterMode && appId != null) {
+ yarnClient.killApplication(appId)
+ } else {
+ setState(SparkAppHandle.State.KILLED)
+ stop()
+ }
+ }
+ }
private val fireAndForget = isClusterMode &&
!sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true)
+ private var appId: ApplicationId = null
+
+ def reportLauncherState(state: SparkAppHandle.State): Unit = {
+ launcherBackend.setState(state)
+ }
+
def stop(): Unit = {
+ launcherBackend.close()
yarnClient.stop()
// Unset YARN mode system env variable, to allow switching between cluster types.
System.clearProperty("SPARK_YARN_MODE")
@@ -103,6 +118,7 @@ private[spark] class Client(
def submitApplication(): ApplicationId = {
var appId: ApplicationId = null
try {
+ launcherBackend.connect()
// Setup the credentials before doing anything else,
+ // so we don't have issues at any point.
setupCredentials()
@@ -116,6 +132,8 @@ private[spark] class Client(
val newApp = yarnClient.createApplication()
val newAppResponse = newApp.getNewApplicationResponse()
appId = newAppResponse.getApplicationId()
+ reportLauncherState(SparkAppHandle.State.SUBMITTED)
+ launcherBackend.setAppId(appId.toString())
// Verify whether the cluster has enough resources for our AM
verifyClusterResources(newAppResponse)
@@ -881,6 +899,20 @@ private[spark] class Client(
}
}
+ if (lastState != state) {
+ state match {
+ case YarnApplicationState.RUNNING =>
+ reportLauncherState(SparkAppHandle.State.RUNNING)
+ case YarnApplicationState.FINISHED =>
+ reportLauncherState(SparkAppHandle.State.FINISHED)
+ case YarnApplicationState.FAILED =>
+ reportLauncherState(SparkAppHandle.State.FAILED)
+ case YarnApplicationState.KILLED =>
+ reportLauncherState(SparkAppHandle.State.KILLED)
+ case _ =>
+ }
+ }
+
if (state == YarnApplicationState.FINISHED ||
state == YarnApplicationState.FAILED ||
state == YarnApplicationState.KILLED) {
@@ -928,8 +960,8 @@ private[spark] class Client(
* throw an appropriate SparkException.
*/
def run(): Unit = {
- val appId = submitApplication()
- if (fireAndForget) {
+ this.appId = submitApplication()
+ if (!launcherBackend.isConnected() && fireAndForget) {
val report = getApplicationReport(appId)
val state = report.getYarnApplicationState
logInfo(s"Application report for $appId (state: $state)")
@@ -971,6 +1003,7 @@ private[spark] class Client(
}
object Client extends Logging {
+
def main(argStrings: Array[String]) {
if (!sys.props.contains("SPARK_SUBMIT")) {
logWarning("WARNING: This client is deprecated and will be removed in a " +
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 36d5759554..20771f6554 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
import org.apache.spark.{SparkException, Logging, SparkContext}
import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil}
+import org.apache.spark.launcher.SparkAppHandle
import org.apache.spark.scheduler.TaskSchedulerImpl
private[spark] class YarnClientSchedulerBackend(
@@ -177,6 +178,15 @@ private[spark] class YarnClientSchedulerBackend(
if (monitorThread != null) {
monitorThread.stopMonitor()
}
+
+ // Report a final state to the launcher if one is connected. This is needed since in client
+ // mode this backend doesn't let the app monitor loop run to completion, so it does not report
+ // the final state itself.
+ //
+ // Note: there's not enough information at this point to provide a better final state,
+ // so assume the application was successful.
+ client.reportLauncherState(SparkAppHandle.State.FINISHED)
+
super.stop()
YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer()
client.stop()
diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties
index 6b8a5dbf63..6b9a799954 100644
--- a/yarn/src/test/resources/log4j.properties
+++ b/yarn/src/test/resources/log4j.properties
@@ -23,6 +23,9 @@ log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
-# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
+# Ignore messages below warning level from a few verbose libraries.
+log4j.logger.com.sun.jersey=WARN
log4j.logger.org.apache.hadoop=WARN
+log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.mortbay=WARN
+log4j.logger.org.spark-project.jetty=WARN
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index 17c59ff06e..12494b0105 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -22,15 +22,18 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
+import scala.concurrent.duration._
+import scala.language.postfixOps
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.server.MiniYARNCluster
import org.scalatest.{BeforeAndAfterAll, Matchers}
+import org.scalatest.concurrent.Eventually._
import org.apache.spark._
-import org.apache.spark.launcher.TestClasspathBuilder
+import org.apache.spark.launcher._
import org.apache.spark.util.Utils
abstract class BaseYarnClusterSuite
@@ -46,13 +49,14 @@ abstract class BaseYarnClusterSuite
|log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
|log4j.logger.org.apache.hadoop=WARN
|log4j.logger.org.eclipse.jetty=WARN
+ |log4j.logger.org.mortbay=WARN
|log4j.logger.org.spark-project.jetty=WARN
""".stripMargin
private var yarnCluster: MiniYARNCluster = _
protected var tempDir: File = _
private var fakeSparkJar: File = _
- private var hadoopConfDir: File = _
+ protected var hadoopConfDir: File = _
private var logConfDir: File = _
def newYarnConfig(): YarnConfiguration
@@ -120,15 +124,77 @@ abstract class BaseYarnClusterSuite
clientMode: Boolean,
klass: String,
appArgs: Seq[String] = Nil,
- sparkArgs: Seq[String] = Nil,
+ sparkArgs: Seq[(String, String)] = Nil,
extraClassPath: Seq[String] = Nil,
extraJars: Seq[String] = Nil,
extraConf: Map[String, String] = Map(),
- extraEnv: Map[String, String] = Map()): Unit = {
+ extraEnv: Map[String, String] = Map()): SparkAppHandle.State = {
val master = if (clientMode) "yarn-client" else "yarn-cluster"
- val props = new Properties()
+ val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf)
+ val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv
+
+ val launcher = new SparkLauncher(env.asJava)
+ if (klass.endsWith(".py")) {
+ launcher.setAppResource(klass)
+ } else {
+ launcher.setMainClass(klass)
+ launcher.setAppResource(fakeSparkJar.getAbsolutePath())
+ }
+ launcher.setSparkHome(sys.props("spark.test.home"))
+ .setMaster(master)
+ .setConf("spark.executor.instances", "1")
+ .setPropertiesFile(propsFile)
+ .addAppArgs(appArgs.toArray: _*)
+
+ sparkArgs.foreach { case (name, value) =>
+ if (value != null) {
+ launcher.addSparkArg(name, value)
+ } else {
+ launcher.addSparkArg(name)
+ }
+ }
+ extraJars.foreach(launcher.addJar)
- props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath())
+ val handle = launcher.startApplication()
+ try {
+ eventually(timeout(2 minutes), interval(1 second)) {
+ assert(handle.getState().isFinal())
+ }
+ } finally {
+ handle.kill()
+ }
+
+ handle.getState()
+ }
+
+ /**
+ * This is a workaround for an issue with yarn-cluster mode: the Client class does not report
+ * an error when the submission process exits successfully even though the job itself failed.
+ * So the tests require the job to write a marker to a file once everything has worked, and
+ * verify that file's contents.
+ */
+ protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = {
+ checkResult(finalState, result, "success")
+ }
+
+ protected def checkResult(
+ finalState: SparkAppHandle.State,
+ result: File,
+ expected: String): Unit = {
+ finalState should be (SparkAppHandle.State.FINISHED)
+ val resultString = Files.toString(result, UTF_8)
+ resultString should be (expected)
+ }
+
+ protected def mainClassName(klass: Class[_]): String = {
+ klass.getName().stripSuffix("$")
+ }
+
+ protected def createConfFile(
+ extraClassPath: Seq[String] = Nil,
+ extraConf: Map[String, String] = Map()): String = {
+ val props = new Properties()
+ props.put("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath())
val testClasspath = new TestClasspathBuilder()
.buildClassPath(
@@ -138,69 +204,28 @@ abstract class BaseYarnClusterSuite
.asScala
.mkString(File.pathSeparator)
- props.setProperty("spark.driver.extraClassPath", testClasspath)
- props.setProperty("spark.executor.extraClassPath", testClasspath)
+ props.put("spark.driver.extraClassPath", testClasspath)
+ props.put("spark.executor.extraClassPath", testClasspath)
// SPARK-4267: make sure java options are propagated correctly.
props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
- yarnCluster.getConfig.asScala.foreach { e =>
+ yarnCluster.getConfig().asScala.foreach { e =>
props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
}
-
sys.props.foreach { case (k, v) =>
if (k.startsWith("spark.")) {
props.setProperty(k, v)
}
}
-
extraConf.foreach { case (k, v) => props.setProperty(k, v) }
val propsFile = File.createTempFile("spark", ".properties", tempDir)
val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8)
props.store(writer, "Spark properties.")
writer.close()
-
- val extraJarArgs = if (extraJars.nonEmpty) Seq("--jars", extraJars.mkString(",")) else Nil
- val mainArgs =
- if (klass.endsWith(".py")) {
- Seq(klass)
- } else {
- Seq("--class", klass, fakeSparkJar.getAbsolutePath())
- }
- val argv =
- Seq(
- new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(),
- "--master", master,
- "--num-executors", "1",
- "--properties-file", propsFile.getAbsolutePath()) ++
- extraJarArgs ++
- sparkArgs ++
- mainArgs ++
- appArgs
-
- Utils.executeAndGetOutput(argv,
- extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv)
- }
-
- /**
- * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
- * any sort of error when the job process finishes successfully, but the job itself fails. So
- * the tests enforce that something is written to a file after everything is ok to indicate
- * that the job succeeded.
- */
- protected def checkResult(result: File): Unit = {
- checkResult(result, "success")
- }
-
- protected def checkResult(result: File, expected: String): Unit = {
- val resultString = Files.toString(result, UTF_8)
- resultString should be (expected)
- }
-
- protected def mainClassName(klass: Class[_]): String = {
- klass.getName().stripSuffix("$")
+ propsFile.getAbsolutePath()
}
}
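
runSpark() above now waits for handle.getState().isFinal() via ScalaTest's eventually; outside a test framework the same idea is a simple bounded poll. A hedged Java sketch (the poll interval is arbitrary):

    import org.apache.spark.launcher.SparkAppHandle;

    final class AwaitFinal {
      // Hypothetical helper: polls until the handle reaches a final state or the deadline passes.
      static SparkAppHandle.State await(SparkAppHandle handle, long timeoutMillis)
          throws InterruptedException {
        long deadline = System.currentTimeMillis() + timeoutMillis;
        while (!handle.getState().isFinal() && System.currentTimeMillis() < deadline) {
          Thread.sleep(500);
        }
        return handle.getState();
      }
    }
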
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index f1601cd161..d1cd0c89b5 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -19,16 +19,20 @@ package org.apache.spark.deploy.yarn
import java.io.File
import java.net.URL
+import java.util.{HashMap => JHashMap, Properties}
import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.{ByteStreams, Files}
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.scalatest.Matchers
+import org.scalatest.concurrent.Eventually._
import org.apache.spark._
-import org.apache.spark.launcher.TestClasspathBuilder
+import org.apache.spark.launcher._
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
SparkListenerExecutorAdded}
import org.apache.spark.scheduler.cluster.ExecutorInfo
@@ -82,10 +86,8 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
test("run Spark in yarn-cluster mode unsuccessfully") {
// Don't provide arguments so the driver will fail.
- val exception = intercept[SparkException] {
- runSpark(false, mainClassName(YarnClusterDriver.getClass))
- fail("Spark application should have failed.")
- }
+ val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass))
+ finalState should be (SparkAppHandle.State.FAILED)
}
test("run Python application in yarn-client mode") {
@@ -104,11 +106,42 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
testUseClassPathFirst(false)
}
+ test("monitor app using launcher library") {
+ val env = new JHashMap[String, String]()
+ env.put("YARN_CONF_DIR", hadoopConfDir.getAbsolutePath())
+
+ val propsFile = createConfFile()
+ val handle = new SparkLauncher(env)
+ .setSparkHome(sys.props("spark.test.home"))
+ .setConf("spark.ui.enabled", "false")
+ .setPropertiesFile(propsFile)
+ .setMaster("yarn-client")
+ .setAppResource("spark-internal")
+ .setMainClass(mainClassName(YarnLauncherTestApp.getClass))
+ .startApplication()
+
+ try {
+ eventually(timeout(30 seconds), interval(100 millis)) {
+ handle.getState() should be (SparkAppHandle.State.RUNNING)
+ }
+
+ handle.getAppId() should not be (null)
+ handle.getAppId() should startWith ("application_")
+ handle.stop()
+
+ eventually(timeout(30 seconds), interval(100 millis)) {
+ handle.getState() should be (SparkAppHandle.State.KILLED)
+ }
+ } finally {
+ handle.kill()
+ }
+ }
+
private def testBasicYarnApp(clientMode: Boolean): Unit = {
val result = File.createTempFile("result", null, tempDir)
- runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
+ val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
appArgs = Seq(result.getAbsolutePath()))
- checkResult(result)
+ checkResult(finalState, result)
}
private def testPySpark(clientMode: Boolean): Unit = {
@@ -143,11 +176,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",")
val result = File.createTempFile("result", null, tempDir)
- runSpark(clientMode, primaryPyFile.getAbsolutePath(),
- sparkArgs = Seq("--py-files", pyFiles),
+ val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(),
+ sparkArgs = Seq("--py-files" -> pyFiles),
appArgs = Seq(result.getAbsolutePath()),
extraEnv = extraEnv)
- checkResult(result)
+ checkResult(finalState, result)
}
private def testUseClassPathFirst(clientMode: Boolean): Unit = {
@@ -156,15 +189,15 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir)
val driverResult = File.createTempFile("driver", null, tempDir)
val executorResult = File.createTempFile("executor", null, tempDir)
- runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
+ val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()),
extraClassPath = Seq(originalJar.getPath()),
extraJars = Seq("local:" + userJar.getPath()),
extraConf = Map(
"spark.driver.userClassPathFirst" -> "true",
"spark.executor.userClassPathFirst" -> "true"))
- checkResult(driverResult, "OVERRIDDEN")
- checkResult(executorResult, "OVERRIDDEN")
+ checkResult(finalState, driverResult, "OVERRIDDEN")
+ checkResult(finalState, executorResult, "OVERRIDDEN")
}
}
@@ -211,8 +244,8 @@ private object YarnClusterDriver extends Logging with Matchers {
data should be (Set(1, 2, 3, 4))
result = "success"
} finally {
- sc.stop()
Files.write(result, status, UTF_8)
+ sc.stop()
}
// verify log urls are present
@@ -297,3 +330,18 @@ private object YarnClasspathTest extends Logging {
}
}
+
+private object YarnLauncherTestApp {
+
+ def main(args: Array[String]): Unit = {
+ // Do not stop the application; the test will stop it using the launcher lib. Just run a task
+ // that will prevent the process from exiting.
+ val sc = new SparkContext(new SparkConf())
+ sc.parallelize(Seq(1)).foreach { i =>
+ this.synchronized {
+ wait()
+ }
+ }
+ }
+
+}
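
The launcher test above stops the application with handle.stop() and only falls back to kill() in the finally block; the same stop-then-kill pattern applies to regular callers. A hedged, self-contained sketch of that pattern (the grace period is arbitrary):

    import org.apache.spark.launcher.SparkAppHandle;

    final class GracefulStop {
      // Hypothetical helper: request a stop, give the app time to wind down, then force-kill.
      static void stopOrKill(SparkAppHandle handle, long graceMillis) throws InterruptedException {
        handle.stop();  // asks the application to shut down via the launcher connection
        long deadline = System.currentTimeMillis() + graceMillis;
        while (!handle.getState().isFinal() && System.currentTimeMillis() < deadline) {
          Thread.sleep(200);
        }
        if (!handle.getState().isFinal()) {
          handle.kill();  // terminate the child process (or the YARN application) outright
        }
      }
    }
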
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
index a85e5772a0..c17e8695c2 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
@@ -53,7 +53,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite {
logInfo("Shuffle service port = " + shuffleServicePort)
val result = File.createTempFile("result", null, tempDir)
- runSpark(
+ val finalState = runSpark(
false,
mainClassName(YarnExternalShuffleDriver.getClass),
appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath),
@@ -62,7 +62,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite {
"spark.shuffle.service.port" -> shuffleServicePort.toString
)
)
- checkResult(result)
+ checkResult(finalState, result)
assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists())
}
}