diff options
29 files changed, 1820 insertions, 146 deletions
diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala new file mode 100644 index 0000000000..3ea984c501 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.net.{InetAddress, Socket} + +import org.apache.spark.SPARK_VERSION +import org.apache.spark.launcher.LauncherProtocol._ +import org.apache.spark.util.ThreadUtils + +/** + * A class that can be used to talk to a launcher server. Users should extend this class to + * provide implementation for the abstract methods. + * + * See `LauncherServer` for an explanation of how launcher communication works. + */ +private[spark] abstract class LauncherBackend { + + private var clientThread: Thread = _ + private var connection: BackendConnection = _ + private var lastState: SparkAppHandle.State = _ + @volatile private var _isConnected = false + + def connect(): Unit = { + val port = sys.env.get(LauncherProtocol.ENV_LAUNCHER_PORT).map(_.toInt) + val secret = sys.env.get(LauncherProtocol.ENV_LAUNCHER_SECRET) + if (port != None && secret != None) { + val s = new Socket(InetAddress.getLoopbackAddress(), port.get) + connection = new BackendConnection(s) + connection.send(new Hello(secret.get, SPARK_VERSION)) + clientThread = LauncherBackend.threadFactory.newThread(connection) + clientThread.start() + _isConnected = true + } + } + + def close(): Unit = { + if (connection != null) { + try { + connection.close() + } finally { + if (clientThread != null) { + clientThread.join() + } + } + } + } + + def setAppId(appId: String): Unit = { + if (connection != null) { + connection.send(new SetAppId(appId)) + } + } + + def setState(state: SparkAppHandle.State): Unit = { + if (connection != null && lastState != state) { + connection.send(new SetState(state)) + lastState = state + } + } + + /** Return whether the launcher handle is still connected to this backend. */ + def isConnected(): Boolean = _isConnected + + /** + * Implementations should provide this method, which should try to stop the application + * as gracefully as possible. + */ + protected def onStopRequest(): Unit + + /** + * Callback for when the launcher handle disconnects from this backend. + */ + protected def onDisconnected() : Unit = { } + + + private class BackendConnection(s: Socket) extends LauncherConnection(s) { + + override protected def handle(m: Message): Unit = m match { + case _: Stop => + onStopRequest() + + case _ => + throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}") + } + + override def close(): Unit = { + try { + super.close() + } finally { + onDisconnected() + _isConnected = false + } + } + + } + +} + +private object LauncherBackend { + + val threadFactory = ThreadUtils.namedThreadFactory("LauncherBackend") + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 27491ecf8b..2625c3e7ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,6 +23,7 @@ import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -36,6 +37,9 @@ private[spark] class SparkDeploySchedulerBackend( private var client: AppClient = null private var stopping = false + private val launcherBackend = new LauncherBackend() { + override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + } @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _ @volatile private var appId: String = _ @@ -47,6 +51,7 @@ private[spark] class SparkDeploySchedulerBackend( override def start() { super.start() + launcherBackend.connect() // The endpoint for executors to talk to us val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, @@ -87,24 +92,20 @@ private[spark] class SparkDeploySchedulerBackend( command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) waitForRegistration() + launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop() { - stopping = true - super.stop() - client.stop() - - val callback = shutdownCallback - if (callback != null) { - callback(this) - } + override def stop(): Unit = synchronized { + stop(SparkAppHandle.State.FINISHED) } override def connected(appId: String) { logInfo("Connected to Spark cluster with app ID " + appId) this.appId = appId notifyContext() + launcherBackend.setAppId(appId) } override def disconnected() { @@ -117,6 +118,7 @@ private[spark] class SparkDeploySchedulerBackend( override def dead(reason: String) { notifyContext() if (!stopping) { + launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. Reason: " + reason) try { scheduler.error(reason) @@ -188,4 +190,19 @@ private[spark] class SparkDeploySchedulerBackend( registrationBarrier.release() } + private def stop(finalState: SparkAppHandle.State): Unit = synchronized { + stopping = true + + launcherBackend.setState(finalState) + launcherBackend.close() + + super.stop() + client.stop() + + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 4d48fcfea4..c633d860ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -24,6 +24,7 @@ import java.nio.ByteBuffer import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -103,6 +104,9 @@ private[spark] class LocalBackend( private var localEndpoint: RpcEndpointRef = null private val userClassPath = getUserClasspath(conf) private val listenerBus = scheduler.sc.listenerBus + private val launcherBackend = new LauncherBackend() { + override def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + } /** * Returns a list of URLs representing the user classpath. @@ -114,6 +118,8 @@ private[spark] class LocalBackend( userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) } + launcherBackend.connect() + override def start() { val rpcEnv = SparkEnv.get.rpcEnv val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) @@ -122,10 +128,12 @@ private[spark] class LocalBackend( System.currentTimeMillis, executorEndpoint.localExecutorId, new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) + launcherBackend.setAppId(appId) + launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def stop() { - localEndpoint.ask(StopExecutor) + stop(SparkAppHandle.State.FINISHED) } override def reviveOffers() { @@ -145,4 +153,13 @@ private[spark] class LocalBackend( override def applicationId(): String = appId + private def stop(finalState: SparkAppHandle.State): Unit = { + localEndpoint.ask(StopExecutor) + try { + launcherBackend.setState(finalState) + } finally { + launcherBackend.close() + } + } + } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index d0c26dd056..aa15e792e2 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -27,6 +27,7 @@ import java.util.Map; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.slf4j.bridge.SLF4JBridgeHandler; import static org.junit.Assert.*; /** @@ -34,7 +35,13 @@ import static org.junit.Assert.*; */ public class SparkLauncherSuite { + static { + SLF4JBridgeHandler.removeHandlersForRootLogger(); + SLF4JBridgeHandler.install(); + } + private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + private static final NamedThreadFactory TF = new NamedThreadFactory("SparkLauncherSuite-%d"); @Test public void testSparkArgumentHandling() throws Exception { @@ -94,14 +101,15 @@ public class SparkLauncherSuite { .addSparkArg(opts.CONF, String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, - "-Dfoo=bar -Dtest.name=-testChildProcLauncher") + "-Dfoo=bar -Dtest.appender=childproc") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) .addAppArgs("proc"); final Process app = launcher.launch(); - new Redirector("stdout", app.getInputStream()).start(); - new Redirector("stderr", app.getErrorStream()).start(); + + new OutputRedirector(app.getInputStream(), TF); + new OutputRedirector(app.getErrorStream(), TF); assertEquals(0, app.waitFor()); } @@ -116,29 +124,4 @@ public class SparkLauncherSuite { } - private static class Redirector extends Thread { - - private final InputStream in; - - Redirector(String name, InputStream in) { - this.in = in; - setName(name); - setDaemon(true); - } - - @Override - public void run() { - try { - BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); - String line; - while ((line = reader.readLine()) != null) { - LOG.warn(line); - } - } catch (Exception e) { - LOG.error("Error reading process output.", e); - } - } - - } - } diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index eb3b1999eb..a54d27de91 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -16,13 +16,22 @@ # # Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file +test.appender=file +log4j.rootCategory=INFO, ${test.appender} log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n +# Tests that launch java subprocesses can set the "test.appender" system property to +# "console" to avoid having the child process's logs overwrite the unit test's +# log file. +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%t: %m%n + # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.spark-project.jetty=WARN org.spark-project.jetty.LEVEL=WARN diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala new file mode 100644 index 0000000000..07e8869833 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.launcher._ + +class LauncherBackendSuite extends SparkFunSuite with Matchers { + + private val tests = Seq( + "local" -> "local", + "standalone/client" -> "local-cluster[1,1,1024]") + + tests.foreach { case (name, master) => + test(s"$name: launcher handle") { + testWithMaster(master) + } + } + + private def testWithMaster(master: String): Unit = { + val env = new java.util.HashMap[String, String]() + env.put("SPARK_PRINT_LAUNCH_COMMAND", "1") + val handle = new SparkLauncher(env) + .setSparkHome(sys.props("spark.test.home")) + .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .setConf("spark.ui.enabled", "false") + .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, s"-Dtest.appender=console") + .setMaster(master) + .setAppResource("spark-internal") + .setMainClass(TestApp.getClass.getName().stripSuffix("$")) + .startApplication() + + try { + eventually(timeout(10 seconds), interval(100 millis)) { + handle.getAppId() should not be (null) + } + + handle.stop() + + eventually(timeout(10 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.KILLED) + } + } finally { + handle.kill() + } + } + +} + +object TestApp { + + def main(args: Array[String]): Unit = { + new SparkContext(new SparkConf()).parallelize(Seq(1)).foreach { i => + Thread.sleep(TimeUnit.SECONDS.toMillis(20)) + } + } + +} diff --git a/launcher/pom.xml b/launcher/pom.xml index d595d74642..5739bfc169 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -49,6 +49,11 @@ </dependency> <dependency> <groupId>org.slf4j</groupId> + <artifactId>jul-to-slf4j</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> <artifactId>slf4j-api</artifactId> <scope>test</scope> </dependency> diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 610e8bdaaa..cf3729b7fe 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -47,7 +47,7 @@ abstract class AbstractCommandBuilder { String javaHome; String mainClass; String master; - String propertiesFile; + protected String propertiesFile; final List<String> appArgs; final List<String> jars; final List<String> files; @@ -55,6 +55,10 @@ abstract class AbstractCommandBuilder { final Map<String, String> childEnv; final Map<String, String> conf; + // The merged configuration for the application. Cached to avoid having to read / parse + // properties files multiple times. + private Map<String, String> effectiveConfig; + public AbstractCommandBuilder() { this.appArgs = new ArrayList<String>(); this.childEnv = new HashMap<String, String>(); @@ -257,12 +261,38 @@ abstract class AbstractCommandBuilder { return path; } + String getenv(String key) { + return firstNonEmpty(childEnv.get(key), System.getenv(key)); + } + + void setPropertiesFile(String path) { + effectiveConfig = null; + this.propertiesFile = path; + } + + Map<String, String> getEffectiveConfig() throws IOException { + if (effectiveConfig == null) { + if (propertiesFile == null) { + effectiveConfig = conf; + } else { + effectiveConfig = new HashMap<>(conf); + Properties p = loadPropertiesFile(); + for (String key : p.stringPropertyNames()) { + if (!effectiveConfig.containsKey(key)) { + effectiveConfig.put(key, p.getProperty(key)); + } + } + } + } + return effectiveConfig; + } + /** * Loads the configuration file for the application, if it exists. This is either the * user-specified properties file, or the spark-defaults.conf file under the Spark configuration * directory. */ - Properties loadPropertiesFile() throws IOException { + private Properties loadPropertiesFile() throws IOException { Properties props = new Properties(); File propsFile; if (propertiesFile != null) { @@ -294,10 +324,6 @@ abstract class AbstractCommandBuilder { return props; } - String getenv(String key) { - return firstNonEmpty(childEnv.get(key), System.getenv(key)); - } - private String findAssembly() { String sparkHome = getSparkHome(); File libdir; diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java new file mode 100644 index 0000000000..de50f14fbd --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadFactory; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Handle implementation for monitoring apps started as a child process. + */ +class ChildProcAppHandle implements SparkAppHandle { + + private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final ThreadFactory REDIRECTOR_FACTORY = + new NamedThreadFactory("launcher-proc-%d"); + + private final String secret; + private final LauncherServer server; + + private Process childProc; + private boolean disposed; + private LauncherConnection connection; + private List<Listener> listeners; + private State state; + private String appId; + private OutputRedirector redirector; + + ChildProcAppHandle(String secret, LauncherServer server) { + this.secret = secret; + this.server = server; + this.state = State.UNKNOWN; + } + + @Override + public synchronized void addListener(Listener l) { + if (listeners == null) { + listeners = new ArrayList<>(); + } + listeners.add(l); + } + + @Override + public State getState() { + return state; + } + + @Override + public String getAppId() { + return appId; + } + + @Override + public void stop() { + CommandBuilderUtils.checkState(connection != null, "Application is still not connected."); + try { + connection.send(new LauncherProtocol.Stop()); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } + + @Override + public synchronized void disconnect() { + if (!disposed) { + disposed = true; + if (connection != null) { + try { + connection.close(); + } catch (IOException ioe) { + // no-op. + } + } + server.unregister(this); + if (redirector != null) { + redirector.stop(); + } + } + } + + @Override + public synchronized void kill() { + if (!disposed) { + disconnect(); + } + if (childProc != null) { + childProc.destroy(); + childProc = null; + } + } + + String getSecret() { + return secret; + } + + void setChildProc(Process childProc, String loggerName) { + this.childProc = childProc; + this.redirector = new OutputRedirector(childProc.getInputStream(), loggerName, + REDIRECTOR_FACTORY); + } + + void setConnection(LauncherConnection connection) { + this.connection = connection; + } + + LauncherServer getServer() { + return server; + } + + LauncherConnection getConnection() { + return connection; + } + + void setState(State s) { + if (!state.isFinal()) { + state = s; + fireEvent(false); + } else { + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { state, s }); + } + } + + void setAppId(String appId) { + this.appId = appId; + fireEvent(true); + } + + private synchronized void fireEvent(boolean isInfoChanged) { + if (listeners != null) { + for (Listener l : listeners) { + if (isInfoChanged) { + l.infoChanged(this); + } else { + l.stateChanged(this); + } + } + } + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java new file mode 100644 index 0000000000..eec264909b --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.net.Socket; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +/** + * Encapsulates a connection between a launcher server and client. This takes care of the + * communication (sending and receiving messages), while processing of messages is left for + * the implementations. + */ +abstract class LauncherConnection implements Closeable, Runnable { + + private static final Logger LOG = Logger.getLogger(LauncherConnection.class.getName()); + + private final Socket socket; + private final ObjectOutputStream out; + + private volatile boolean closed; + + LauncherConnection(Socket socket) throws IOException { + this.socket = socket; + this.out = new ObjectOutputStream(socket.getOutputStream()); + this.closed = false; + } + + protected abstract void handle(Message msg) throws IOException; + + @Override + public void run() { + try { + ObjectInputStream in = new ObjectInputStream(socket.getInputStream()); + while (!closed) { + Message msg = (Message) in.readObject(); + handle(msg); + } + } catch (EOFException eof) { + // Remote side has closed the connection, just cleanup. + try { + close(); + } catch (Exception unused) { + // no-op. + } + } catch (Exception e) { + if (!closed) { + LOG.log(Level.WARNING, "Error in inbound message handling.", e); + try { + close(); + } catch (Exception unused) { + // no-op. + } + } + } + } + + protected synchronized void send(Message msg) throws IOException { + try { + CommandBuilderUtils.checkState(!closed, "Disconnected."); + out.writeObject(msg); + out.flush(); + } catch (IOException ioe) { + if (!closed) { + LOG.log(Level.WARNING, "Error when sending message.", ioe); + try { + close(); + } catch (Exception unused) { + // no-op. + } + } + throw ioe; + } + } + + @Override + public void close() throws IOException { + if (!closed) { + synchronized (this) { + if (!closed) { + closed = true; + socket.close(); + } + } + } + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java new file mode 100644 index 0000000000..50f136497e --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.net.Socket; +import java.util.Map; + +/** + * Message definitions for the launcher communication protocol. These messages must remain + * backwards-compatible, so that the launcher can talk to older versions of Spark that support + * the protocol. + */ +final class LauncherProtocol { + + /** Environment variable where the server port is stored. */ + static final String ENV_LAUNCHER_PORT = "_SPARK_LAUNCHER_PORT"; + + /** Environment variable where the secret for connecting back to the server is stored. */ + static final String ENV_LAUNCHER_SECRET = "_SPARK_LAUNCHER_SECRET"; + + static class Message implements Serializable { + + } + + /** + * Hello message, sent from client to server. + */ + static class Hello extends Message { + + final String secret; + final String sparkVersion; + + Hello(String secret, String version) { + this.secret = secret; + this.sparkVersion = version; + } + + } + + /** + * SetAppId message, sent from client to server. + */ + static class SetAppId extends Message { + + final String appId; + + SetAppId(String appId) { + this.appId = appId; + } + + } + + /** + * SetState message, sent from client to server. + */ + static class SetState extends Message { + + final SparkAppHandle.State state; + + SetState(SparkAppHandle.State state) { + this.state = state; + } + + } + + /** + * Stop message, send from server to client to stop the application. + */ + static class Stop extends Message { + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java new file mode 100644 index 0000000000..c5fd40816d --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.List; +import java.util.Timer; +import java.util.TimerTask; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +/** + * A server that listens locally for connections from client launched by the library. Each client + * has a secret that it needs to send to the server to identify itself and establish the session. + * + * I/O is currently blocking (one thread per client). Clients have a limited time to connect back + * to the server, otherwise the server will ignore the connection. + * + * === Architecture Overview === + * + * The launcher server is used when Spark apps are launched as separate processes than the calling + * app. It looks more or less like the following: + * + * ----------------------- ----------------------- + * | User App | spark-submit | Spark App | + * | | -------------------> | | + * | ------------| |------------- | + * | | | hello | | | + * | | L. Server |<----------------------| L. Backend | | + * | | | | | | + * | ------------- ----------------------- + * | | | ^ + * | v | | + * | -------------| | + * | | | <per-app channel> | + * | | App Handle |<------------------------------ + * | | | + * ----------------------- + * + * The server is started on demand and remains active while there are active or outstanding clients, + * to avoid opening too many ports when multiple clients are launched. Each client is given a unique + * secret, and have a limited amount of time to connect back + * ({@link SparkLauncher#CHILD_CONNECTION_TIMEOUT}), at which point the server will throw away + * that client's state. A client is only allowed to connect back to the server once. + * + * The launcher server listens on the localhost only, so it doesn't need access controls (aside from + * the per-app secret) nor encryption. It thus requires that the launched app has a local process + * that communicates with the server. In cluster mode, this means that the client that launches the + * application must remain alive for the duration of the application (or until the app handle is + * disconnected). + */ +class LauncherServer implements Closeable { + + private static final Logger LOG = Logger.getLogger(LauncherServer.class.getName()); + private static final String THREAD_NAME_FMT = "LauncherServer-%d"; + private static final long DEFAULT_CONNECT_TIMEOUT = 10000L; + + /** For creating secrets used for communication with child processes. */ + private static final SecureRandom RND = new SecureRandom(); + + private static volatile LauncherServer serverInstance; + + /** + * Creates a handle for an app to be launched. This method will start a server if one hasn't been + * started yet. The server is shared for multiple handles, and once all handles are disposed of, + * the server is shut down. + */ + static synchronized ChildProcAppHandle newAppHandle() throws IOException { + LauncherServer server = serverInstance != null ? serverInstance : new LauncherServer(); + server.ref(); + serverInstance = server; + + String secret = server.createSecret(); + while (server.pending.containsKey(secret)) { + secret = server.createSecret(); + } + + return server.newAppHandle(secret); + } + + static LauncherServer getServerInstance() { + return serverInstance; + } + + private final AtomicLong refCount; + private final AtomicLong threadIds; + private final ConcurrentMap<String, ChildProcAppHandle> pending; + private final List<ServerConnection> clients; + private final ServerSocket server; + private final Thread serverThread; + private final ThreadFactory factory; + private final Timer timeoutTimer; + + private volatile boolean running; + + private LauncherServer() throws IOException { + this.refCount = new AtomicLong(0); + + ServerSocket server = new ServerSocket(); + try { + server.setReuseAddress(true); + server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); + + this.clients = new ArrayList<ServerConnection>(); + this.threadIds = new AtomicLong(); + this.factory = new NamedThreadFactory(THREAD_NAME_FMT); + this.pending = new ConcurrentHashMap<>(); + this.timeoutTimer = new Timer("LauncherServer-TimeoutTimer", true); + this.server = server; + this.running = true; + + this.serverThread = factory.newThread(new Runnable() { + @Override + public void run() { + acceptConnections(); + } + }); + serverThread.start(); + } catch (IOException ioe) { + close(); + throw ioe; + } catch (Exception e) { + close(); + throw new IOException(e); + } + } + + /** + * Creates a new app handle. The handle will wait for an incoming connection for a configurable + * amount of time, and if one doesn't arrive, it will transition to an error state. + */ + ChildProcAppHandle newAppHandle(String secret) { + ChildProcAppHandle handle = new ChildProcAppHandle(secret, this); + ChildProcAppHandle existing = pending.putIfAbsent(secret, handle); + CommandBuilderUtils.checkState(existing == null, "Multiple handles with the same secret."); + return handle; + } + + @Override + public void close() throws IOException { + synchronized (this) { + if (running) { + running = false; + timeoutTimer.cancel(); + server.close(); + synchronized (clients) { + List<ServerConnection> copy = new ArrayList<>(clients); + clients.clear(); + for (ServerConnection client : copy) { + client.close(); + } + } + } + } + if (serverThread != null) { + try { + serverThread.join(); + } catch (InterruptedException ie) { + // no-op + } + } + } + + void ref() { + refCount.incrementAndGet(); + } + + void unref() { + synchronized(LauncherServer.class) { + if (refCount.decrementAndGet() == 0) { + try { + close(); + } catch (IOException ioe) { + // no-op. + } finally { + serverInstance = null; + } + } + } + } + + int getPort() { + return server.getLocalPort(); + } + + /** + * Removes the client handle from the pending list (in case it's still there), and unrefs + * the server. + */ + void unregister(ChildProcAppHandle handle) { + pending.remove(handle.getSecret()); + unref(); + } + + private void acceptConnections() { + try { + while (running) { + final Socket client = server.accept(); + TimerTask timeout = new TimerTask() { + @Override + public void run() { + LOG.warning("Timed out waiting for hello message from client."); + try { + client.close(); + } catch (IOException ioe) { + // no-op. + } + } + }; + ServerConnection clientConnection = new ServerConnection(client, timeout); + Thread clientThread = factory.newThread(clientConnection); + synchronized (timeout) { + clientThread.start(); + synchronized (clients) { + clients.add(clientConnection); + } + timeoutTimer.schedule(timeout, getConnectionTimeout()); + } + } + } catch (IOException ioe) { + if (running) { + LOG.log(Level.SEVERE, "Error in accept loop.", ioe); + } + } + } + + private long getConnectionTimeout() { + String value = SparkLauncher.launcherConfig.get(SparkLauncher.CHILD_CONNECTION_TIMEOUT); + return (value != null) ? Long.parseLong(value) : DEFAULT_CONNECT_TIMEOUT; + } + + private String createSecret() { + byte[] secret = new byte[128]; + RND.nextBytes(secret); + + StringBuilder sb = new StringBuilder(); + for (byte b : secret) { + int ival = b >= 0 ? b : Byte.MAX_VALUE - b; + if (ival < 0x10) { + sb.append("0"); + } + sb.append(Integer.toHexString(ival)); + } + return sb.toString(); + } + + private class ServerConnection extends LauncherConnection { + + private TimerTask timeout; + private ChildProcAppHandle handle; + + ServerConnection(Socket socket, TimerTask timeout) throws IOException { + super(socket); + this.timeout = timeout; + } + + @Override + protected void handle(Message msg) throws IOException { + try { + if (msg instanceof Hello) { + synchronized (timeout) { + timeout.cancel(); + } + timeout = null; + Hello hello = (Hello) msg; + ChildProcAppHandle handle = pending.remove(hello.secret); + if (handle != null) { + handle.setState(SparkAppHandle.State.CONNECTED); + handle.setConnection(this); + this.handle = handle; + } else { + throw new IllegalArgumentException("Received Hello for unknown client."); + } + } else { + if (handle == null) { + throw new IllegalArgumentException("Expected hello, got: " + + msg != null ? msg.getClass().getName() : null); + } + if (msg instanceof SetAppId) { + SetAppId set = (SetAppId) msg; + handle.setAppId(set.appId); + } else if (msg instanceof SetState) { + handle.setState(((SetState)msg).state); + } else { + throw new IllegalArgumentException("Invalid message: " + + msg != null ? msg.getClass().getName() : null); + } + } + } catch (Exception e) { + LOG.log(Level.INFO, "Error handling message from client.", e); + if (timeout != null) { + timeout.cancel(); + } + close(); + } finally { + timeoutTimer.purge(); + } + } + + @Override + public void close() throws IOException { + synchronized (clients) { + clients.remove(this); + } + super.close(); + if (handle != null) { + handle.disconnect(); + } + } + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java b/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java new file mode 100644 index 0000000000..995f4d73da --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; + +class NamedThreadFactory implements ThreadFactory { + + private final String nameFormat; + private final AtomicLong threadIds; + + NamedThreadFactory(String nameFormat) { + this.nameFormat = nameFormat; + this.threadIds = new AtomicLong(); + } + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r, String.format(nameFormat, threadIds.incrementAndGet())); + t.setDaemon(true); + return t; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java new file mode 100644 index 0000000000..6e7120167d --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.IOException; +import java.util.concurrent.ThreadFactory; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Redirects lines read from a given input stream to a j.u.l.Logger (at INFO level). + */ +class OutputRedirector { + + private final BufferedReader reader; + private final Logger sink; + private final Thread thread; + + private volatile boolean active; + + OutputRedirector(InputStream in, ThreadFactory tf) { + this(in, OutputRedirector.class.getName(), tf); + } + + OutputRedirector(InputStream in, String loggerName, ThreadFactory tf) { + this.active = true; + this.reader = new BufferedReader(new InputStreamReader(in)); + this.thread = tf.newThread(new Runnable() { + @Override + public void run() { + redirect(); + } + }); + this.sink = Logger.getLogger(loggerName); + thread.start(); + } + + private void redirect() { + try { + String line; + while ((line = reader.readLine()) != null) { + if (active) { + sink.info(line.replaceFirst("\\s*$", "")); + } + } + } catch (IOException e) { + sink.log(Level.FINE, "Error reading child process output.", e); + } + } + + /** + * This method just stops the output of the process from showing up in the local logs. + * The child's output will still be read (and, thus, the redirect thread will still be + * alive) to avoid the child process hanging because of lack of output buffer. + */ + void stop() { + active = false; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java new file mode 100644 index 0000000000..2896a91d5e --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +/** + * A handle to a running Spark application. + * <p/> + * Provides runtime information about the underlying Spark application, and actions to control it. + * + * @since 1.6.0 + */ +public interface SparkAppHandle { + + /** + * Represents the application's state. A state can be "final", in which case it will not change + * after it's reached, and means the application is not running anymore. + * + * @since 1.6.0 + */ + public enum State { + /** The application has not reported back yet. */ + UNKNOWN(false), + /** The application has connected to the handle. */ + CONNECTED(false), + /** The application has been submitted to the cluster. */ + SUBMITTED(false), + /** The application is running. */ + RUNNING(false), + /** The application finished with a successful status. */ + FINISHED(true), + /** The application finished with a failed status. */ + FAILED(true), + /** The application was killed. */ + KILLED(true); + + private final boolean isFinal; + + State(boolean isFinal) { + this.isFinal = isFinal; + } + + /** + * Whether this state is a final state, meaning the application is not running anymore + * once it's reached. + */ + public boolean isFinal() { + return isFinal; + } + } + + /** + * Adds a listener to be notified of changes to the handle's information. Listeners will be called + * from the thread processing updates from the application, so they should avoid blocking or + * long-running operations. + * + * @param l Listener to add. + */ + void addListener(Listener l); + + /** Returns the current application state. */ + State getState(); + + /** Returns the application ID, or <code>null</code> if not yet known. */ + String getAppId(); + + /** + * Asks the application to stop. This is best-effort, since the application may fail to receive + * or act on the command. Callers should watch for a state transition that indicates the + * application has really stopped. + */ + void stop(); + + /** + * Tries to kill the underlying application. Implies {@link #disconnect()}. This will not send + * a {@link #stop()} message to the application, so it's recommended that users first try to + * stop the application cleanly and only resort to this method if that fails. + */ + void kill(); + + /** + * Disconnects the handle from the application, without stopping it. After this method is called, + * the handle will not be able to communicate with the application anymore. + */ + void disconnect(); + + /** + * Listener for updates to a handle's state. The callbacks do not receive information about + * what exactly has changed, just that an update has occurred. + * + * @since 1.6.0 + */ + public interface Listener { + + /** + * Callback for changes in the handle's state. + * + * @param handle The updated handle. + * @see {@link SparkAppHandle#getState()} + */ + void stateChanged(SparkAppHandle handle); + + /** + * Callback for changes in any information that is not the handle's state. + * + * @param handle The updated handle. + */ + void infoChanged(SparkAppHandle handle); + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 57993405e4..5d74b37033 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -21,8 +21,10 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -58,6 +60,33 @@ public class SparkLauncher { /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; + /** Logger name to use when launching a child process. */ + public static final String CHILD_PROCESS_LOGGER_NAME = "spark.launcher.childProcLoggerName"; + + /** + * Maximum time (in ms) to wait for a child process to connect back to the launcher server + * when using @link{#start()}. + */ + public static final String CHILD_CONNECTION_TIMEOUT = "spark.launcher.childConectionTimeout"; + + /** Used internally to create unique logger names. */ + private static final AtomicInteger COUNTER = new AtomicInteger(); + + static final Map<String, String> launcherConfig = new HashMap<String, String>(); + + /** + * Set a configuration value for the launcher library. These config values do not affect the + * launched application, but rather the behavior of the launcher library itself when managing + * applications. + * + * @since 1.6.0 + * @param name Config name. + * @param value Config value. + */ + public static void setConfig(String name, String value) { + launcherConfig.put(name, value); + } + // Visible for testing. final SparkSubmitCommandBuilder builder; @@ -109,7 +138,7 @@ public class SparkLauncher { */ public SparkLauncher setPropertiesFile(String path) { checkNotNull(path, "path"); - builder.propertiesFile = path; + builder.setPropertiesFile(path); return this; } @@ -197,6 +226,7 @@ public class SparkLauncher { * Use this method with caution. It is possible to create an invalid Spark command by passing * unknown arguments to this method, since those are allowed for forward compatibility. * + * @since 1.5.0 * @param arg Argument to add. * @return This launcher. */ @@ -218,6 +248,7 @@ public class SparkLauncher { * Use this method with caution. It is possible to create an invalid Spark command by passing * unknown arguments to this method, since those are allowed for forward compatibility. * + * @since 1.5.0 * @param name Name of argument to add. * @param value Value of the argument. * @return This launcher. @@ -319,10 +350,81 @@ public class SparkLauncher { /** * Launches a sub-process that will start the configured Spark application. + * <p/> + * The {@link #startApplication(SparkAppHandle.Listener...)} method is preferred when launching + * Spark, since it provides better control of the child application. * * @return A process handle for the Spark app. */ public Process launch() throws IOException { + return createBuilder().start(); + } + + /** + * Starts a Spark application. + * <p/> + * This method returns a handle that provides information about the running application and can + * be used to do basic interaction with it. + * <p/> + * The returned handle assumes that the application will instantiate a single SparkContext + * during its lifetime. Once that context reports a final state (one that indicates the + * SparkContext has stopped), the handle will not perform new state transitions, so anything + * that happens after that cannot be monitored. If the underlying application is launched as + * a child process, {@link SparkAppHandle#kill()} can still be used to kill the child process. + * <p/> + * Currently, all applications are launched as child processes. The child's stdout and stderr + * are merged and written to a logger (see <code>java.util.logging</code>). The logger's name + * can be defined by setting {@link #CHILD_PROCESS_LOGGER_NAME} in the app's configuration. If + * that option is not set, the code will try to derive a name from the application's name or + * main class / script file. If those cannot be determined, an internal, unique name will be + * used. In all cases, the logger name will start with "org.apache.spark.launcher.app", to fit + * more easily into the configuration of commonly-used logging systems. + * + * @since 1.6.0 + * @param listeners Listeners to add to the handle before the app is launched. + * @return A handle for the launched application. + */ + public SparkAppHandle startApplication(SparkAppHandle.Listener... listeners) throws IOException { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + for (SparkAppHandle.Listener l : listeners) { + handle.addListener(l); + } + + String appName = builder.getEffectiveConfig().get(CHILD_PROCESS_LOGGER_NAME); + if (appName == null) { + if (builder.appName != null) { + appName = builder.appName; + } else if (builder.mainClass != null) { + int dot = builder.mainClass.lastIndexOf("."); + if (dot >= 0 && dot < builder.mainClass.length() - 1) { + appName = builder.mainClass.substring(dot + 1, builder.mainClass.length()); + } else { + appName = builder.mainClass; + } + } else if (builder.appResource != null) { + appName = new File(builder.appResource).getName(); + } else { + appName = String.valueOf(COUNTER.incrementAndGet()); + } + } + + String loggerPrefix = getClass().getPackage().getName(); + String loggerName = String.format("%s.app.%s", loggerPrefix, appName); + ProcessBuilder pb = createBuilder().redirectErrorStream(true); + pb.environment().put(LauncherProtocol.ENV_LAUNCHER_PORT, + String.valueOf(LauncherServer.getServerInstance().getPort())); + pb.environment().put(LauncherProtocol.ENV_LAUNCHER_SECRET, handle.getSecret()); + try { + handle.setChildProc(pb.start(), loggerName); + } catch (IOException ioe) { + handle.kill(); + throw ioe; + } + + return handle; + } + + private ProcessBuilder createBuilder() { List<String> cmd = new ArrayList<String>(); String script = isWindows() ? "spark-submit.cmd" : "spark-submit"; cmd.add(join(File.separator, builder.getSparkHome(), "bin", script)); @@ -343,7 +445,7 @@ public class SparkLauncher { for (Map.Entry<String, String> e : builder.childEnv.entrySet()) { pb.environment().put(e.getKey(), e.getValue()); } - return pb.start(); + return pb; } private static class ArgumentValidator extends SparkSubmitOptionParser { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index fc87814a59..39b46e0db8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -188,10 +188,9 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { // Load the properties file and check whether spark-submit will be running the app's driver // or just launching a cluster app. When running the driver, the JVM's argument will be // modified to cover the driver's configuration. - Properties props = loadPropertiesFile(); - boolean isClientMode = isClientMode(props); - String extraClassPath = isClientMode ? - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_CLASSPATH, conf, props) : null; + Map<String, String> config = getEffectiveConfig(); + boolean isClientMode = isClientMode(config); + String extraClassPath = isClientMode ? config.get(SparkLauncher.DRIVER_EXTRA_CLASSPATH) : null; List<String> cmd = buildJavaCommand(extraClassPath); // Take Thrift Server as daemon @@ -212,14 +211,13 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { // Take Thrift Server as daemon String tsMemory = isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; - String memory = firstNonEmpty(tsMemory, - firstNonEmptyValue(SparkLauncher.DRIVER_MEMORY, conf, props), + String memory = firstNonEmpty(tsMemory, config.get(SparkLauncher.DRIVER_MEMORY), System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); cmd.add("-Xms" + memory); cmd.add("-Xmx" + memory); - addOptionString(cmd, firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, conf, props)); + addOptionString(cmd, config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)); mergeEnvPathList(env, getLibPathEnvName(), - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); } addPermGenSizeOpt(cmd); @@ -281,9 +279,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { private void constructEnvVarArgs( Map<String, String> env, String submitArgsEnvVariable) throws IOException { - Properties props = loadPropertiesFile(); mergeEnvPathList(env, getLibPathEnvName(), - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + getEffectiveConfig().get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); StringBuilder submitArgs = new StringBuilder(); for (String arg : buildSparkSubmitArgs()) { @@ -295,9 +292,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { env.put(submitArgsEnvVariable, submitArgs.toString()); } - - private boolean isClientMode(Properties userProps) { - String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); + private boolean isClientMode(Map<String, String> userProps) { + String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER)); // Default master is "local[*]", so assume client mode in that case. return userMaster == null || "client".equals(deployMode) || diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java index 7c97dba511..d1ac39bdc7 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java +++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java @@ -17,17 +17,42 @@ /** * Library for launching Spark applications. - * + * * <p> * This library allows applications to launch Spark programmatically. There's only one entry * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class. * </p> * * <p> - * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher} - * and configure the application to run. For example: + * The {@link org.apache.spark.launcher.SparkLauncher#startApplication( + * org.apache.spark.launcher.SparkAppHandle.Listener...)} can be used to start Spark and provide + * a handle to monitor and control the running application: * </p> - * + * + * <pre> + * {@code + * import org.apache.spark.launcher.SparkAppHandle; + * import org.apache.spark.launcher.SparkLauncher; + * + * public class MyLauncher { + * public static void main(String[] args) throws Exception { + * SparkAppHandle handle = new SparkLauncher() + * .setAppResource("/my/app.jar") + * .setMainClass("my.spark.app.Main") + * .setMaster("local") + * .setConf(SparkLauncher.DRIVER_MEMORY, "2g") + * .startApplication(); + * // Use handle API to monitor / control application. + * } + * } + * } + * </pre> + * + * <p> + * It's also possible to launch a raw child process, using the + * {@link org.apache.spark.launcher.SparkLauncher#launch()} method: + * </p> + * * <pre> * {@code * import org.apache.spark.launcher.SparkLauncher; @@ -45,5 +70,10 @@ * } * } * </pre> + * + * <p>This method requires the calling code to manually manage the child process, including its + * output streams (to avoid possible deadlocks). It's recommended that + * {@link org.apache.spark.launcher.SparkLauncher#startApplication( + * org.apache.spark.launcher.SparkAppHandle.Listener...)} be used instead.</p> */ package org.apache.spark.launcher; diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java new file mode 100644 index 0000000000..23e2c64d6d --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import org.slf4j.bridge.SLF4JBridgeHandler; + +/** + * Handles configuring the JUL -> SLF4J bridge. + */ +class BaseSuite { + + static { + SLF4JBridgeHandler.removeHandlersForRootLogger(); + SLF4JBridgeHandler.install(); + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java new file mode 100644 index 0000000000..27cd1061a1 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.net.Socket; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +public class LauncherServerSuite extends BaseSuite { + + @Test + public void testLauncherServerReuse() throws Exception { + ChildProcAppHandle handle1 = null; + ChildProcAppHandle handle2 = null; + ChildProcAppHandle handle3 = null; + + try { + handle1 = LauncherServer.newAppHandle(); + handle2 = LauncherServer.newAppHandle(); + LauncherServer server1 = handle1.getServer(); + assertSame(server1, handle2.getServer()); + + handle1.kill(); + handle2.kill(); + + handle3 = LauncherServer.newAppHandle(); + assertNotSame(server1, handle3.getServer()); + + handle3.kill(); + + assertNull(LauncherServer.getServerInstance()); + } finally { + kill(handle1); + kill(handle2); + kill(handle3); + } + } + + @Test + public void testCommunication() throws Exception { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + + final Object waitLock = new Object(); + handle.addListener(new SparkAppHandle.Listener() { + @Override + public void stateChanged(SparkAppHandle handle) { + wakeUp(); + } + + @Override + public void infoChanged(SparkAppHandle handle) { + wakeUp(); + } + + private void wakeUp() { + synchronized (waitLock) { + waitLock.notifyAll(); + } + } + }); + + client = new TestClient(s); + synchronized (waitLock) { + client.send(new Hello(handle.getSecret(), "1.4.0")); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + + // Make sure the server matched the client to the handle. + assertNotNull(handle.getConnection()); + + synchronized (waitLock) { + client.send(new SetAppId("app-id")); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + assertEquals("app-id", handle.getAppId()); + + synchronized (waitLock) { + client.send(new SetState(SparkAppHandle.State.RUNNING)); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + assertEquals(SparkAppHandle.State.RUNNING, handle.getState()); + + handle.stop(); + Message stopMsg = client.inbound.poll(10, TimeUnit.SECONDS); + assertTrue(stopMsg instanceof Stop); + } finally { + kill(handle); + close(client); + client.clientThread.join(); + } + } + + @Test + public void testTimeout() throws Exception { + final long TEST_TIMEOUT = 10L; + + ChildProcAppHandle handle = null; + TestClient client = null; + try { + SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, String.valueOf(TEST_TIMEOUT)); + + handle = LauncherServer.newAppHandle(); + + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + client = new TestClient(s); + + Thread.sleep(TEST_TIMEOUT * 10); + try { + client.send(new Hello(handle.getSecret(), "1.4.0")); + fail("Expected exception caused by connection timeout."); + } catch (IllegalStateException e) { + // Expected. + } + } finally { + SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT); + kill(handle); + close(client); + } + } + + private void kill(SparkAppHandle handle) { + if (handle != null) { + handle.kill(); + } + } + + private void close(Closeable c) { + if (c != null) { + try { + c.close(); + } catch (Exception e) { + // no-op. + } + } + } + + private static class TestClient extends LauncherConnection { + + final BlockingQueue<Message> inbound; + final Thread clientThread; + + TestClient(Socket s) throws IOException { + super(s); + this.inbound = new LinkedBlockingQueue<Message>(); + this.clientThread = new Thread(this); + clientThread.setName("TestClient"); + clientThread.setDaemon(true); + clientThread.start(); + } + + @Override + protected void handle(Message msg) throws IOException { + inbound.offer(msg); + } + + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 7329ac9f7f..d5397b0685 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -30,7 +30,7 @@ import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; -public class SparkSubmitCommandBuilderSuite { +public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; @@ -161,7 +161,7 @@ public class SparkSubmitCommandBuilderSuite { launcher.appResource = "/foo"; launcher.appName = "MyApp"; launcher.mainClass = "my.Class"; - launcher.propertiesFile = dummyPropsFile.getAbsolutePath(); + launcher.setPropertiesFile(dummyPropsFile.getAbsolutePath()); launcher.appArgs.add("foo"); launcher.appArgs.add("bar"); launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java index f3d2109917..3ee5b8cf96 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java @@ -28,7 +28,7 @@ import static org.mockito.Mockito.*; import static org.apache.spark.launcher.SparkSubmitOptionParser.*; -public class SparkSubmitOptionParserSuite { +public class SparkSubmitOptionParserSuite extends BaseSuite { private SparkSubmitOptionParser parser; diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index 67a6a98217..c64b1565e1 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -16,16 +16,19 @@ # # Set everything to be logged to the file core/target/unit-tests.log -log4j.rootCategory=INFO, file +test.appender=file +log4j.rootCategory=INFO, ${test.appender} log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false - -# Some tests will set "test.name" to avoid overwriting the main log file. -log4j.appender.file.file=target/unit-tests${test.name}.log - +log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n +log4j.appender.childproc=org.apache.log4j.ConsoleAppender +log4j.appender.childproc.target=System.err +log4j.appender.childproc.layout=org.apache.log4j.PatternLayout +log4j.appender.childproc.layout.ConversionPattern=%t: %m%n + # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.spark-project.jetty=WARN org.spark-project.jetty.LEVEL=WARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index eb3b7fb885..cec81b9406 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -55,8 +55,8 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils private[spark] class Client( @@ -70,8 +70,6 @@ private[spark] class Client( def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) - def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) - private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) private var credentials: Credentials = null @@ -84,10 +82,27 @@ private[spark] class Client( private var principal: String = null private var keytab: String = null + private val launcherBackend = new LauncherBackend() { + override def onStopRequest(): Unit = { + if (isClusterMode && appId != null) { + yarnClient.killApplication(appId) + } else { + setState(SparkAppHandle.State.KILLED) + stop() + } + } + } private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) + private var appId: ApplicationId = null + + def reportLauncherState(state: SparkAppHandle.State): Unit = { + launcherBackend.setState(state) + } + def stop(): Unit = { + launcherBackend.close() yarnClient.stop() // Unset YARN mode system env variable, to allow switching between cluster types. System.clearProperty("SPARK_YARN_MODE") @@ -103,6 +118,7 @@ private[spark] class Client( def submitApplication(): ApplicationId = { var appId: ApplicationId = null try { + launcherBackend.connect() // Setup the credentials before doing anything else, // so we have don't have issues at any point. setupCredentials() @@ -116,6 +132,8 @@ private[spark] class Client( val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() appId = newAppResponse.getApplicationId() + reportLauncherState(SparkAppHandle.State.SUBMITTED) + launcherBackend.setAppId(appId.toString()) // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) @@ -881,6 +899,20 @@ private[spark] class Client( } } + if (lastState != state) { + state match { + case YarnApplicationState.RUNNING => + reportLauncherState(SparkAppHandle.State.RUNNING) + case YarnApplicationState.FINISHED => + reportLauncherState(SparkAppHandle.State.FINISHED) + case YarnApplicationState.FAILED => + reportLauncherState(SparkAppHandle.State.FAILED) + case YarnApplicationState.KILLED => + reportLauncherState(SparkAppHandle.State.KILLED) + case _ => + } + } + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -928,8 +960,8 @@ private[spark] class Client( * throw an appropriate SparkException. */ def run(): Unit = { - val appId = submitApplication() - if (fireAndForget) { + this.appId = submitApplication() + if (!launcherBackend.isConnected() && fireAndForget) { val report = getApplicationReport(appId) val state = report.getYarnApplicationState logInfo(s"Application report for $appId (state: $state)") @@ -971,6 +1003,7 @@ private[spark] class Client( } object Client extends Logging { + def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { logWarning("WARNING: This client is deprecated and will be removed in a " + diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 36d5759554..20771f6554 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.launcher.SparkAppHandle import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( @@ -177,6 +178,15 @@ private[spark] class YarnClientSchedulerBackend( if (monitorThread != null) { monitorThread.stopMonitor() } + + // Report a final state to the launcher if one is connected. This is needed since in client + // mode this backend doesn't let the app monitor loop run to completion, so it does not report + // the final state itself. + // + // Note: there's not enough information at this point to provide a better final state, + // so assume the application was successful. + client.reportLauncherState(SparkAppHandle.State.FINISHED) + super.stop() YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() client.stop() diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 6b8a5dbf63..6b9a799954 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -23,6 +23,9 @@ log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 17c59ff06e..12494b0105 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -22,15 +22,18 @@ import java.util.Properties import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.launcher._ import org.apache.spark.util.Utils abstract class BaseYarnClusterSuite @@ -46,13 +49,14 @@ abstract class BaseYarnClusterSuite |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n |log4j.logger.org.apache.hadoop=WARN |log4j.logger.org.eclipse.jetty=WARN + |log4j.logger.org.mortbay=WARN |log4j.logger.org.spark-project.jetty=WARN """.stripMargin private var yarnCluster: MiniYARNCluster = _ protected var tempDir: File = _ private var fakeSparkJar: File = _ - private var hadoopConfDir: File = _ + protected var hadoopConfDir: File = _ private var logConfDir: File = _ def newYarnConfig(): YarnConfiguration @@ -120,15 +124,77 @@ abstract class BaseYarnClusterSuite clientMode: Boolean, klass: String, appArgs: Seq[String] = Nil, - sparkArgs: Seq[String] = Nil, + sparkArgs: Seq[(String, String)] = Nil, extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), - extraEnv: Map[String, String] = Map()): Unit = { + extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { val master = if (clientMode) "yarn-client" else "yarn-cluster" - val props = new Properties() + val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) + val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv + + val launcher = new SparkLauncher(env.asJava) + if (klass.endsWith(".py")) { + launcher.setAppResource(klass) + } else { + launcher.setMainClass(klass) + launcher.setAppResource(fakeSparkJar.getAbsolutePath()) + } + launcher.setSparkHome(sys.props("spark.test.home")) + .setMaster(master) + .setConf("spark.executor.instances", "1") + .setPropertiesFile(propsFile) + .addAppArgs(appArgs.toArray: _*) + + sparkArgs.foreach { case (name, value) => + if (value != null) { + launcher.addSparkArg(name, value) + } else { + launcher.addSparkArg(name) + } + } + extraJars.foreach(launcher.addJar) - props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + val handle = launcher.startApplication() + try { + eventually(timeout(2 minutes), interval(1 second)) { + assert(handle.getState().isFinal()) + } + } finally { + handle.kill() + } + + handle.getState() + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. + */ + protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = { + checkResult(finalState, result, "success") + } + + protected def checkResult( + finalState: SparkAppHandle.State, + result: File, + expected: String): Unit = { + finalState should be (SparkAppHandle.State.FINISHED) + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + + protected def createConfFile( + extraClassPath: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): String = { + val props = new Properties() + props.put("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) val testClasspath = new TestClasspathBuilder() .buildClassPath( @@ -138,69 +204,28 @@ abstract class BaseYarnClusterSuite .asScala .mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", testClasspath) - props.setProperty("spark.executor.extraClassPath", testClasspath) + props.put("spark.driver.extraClassPath", testClasspath) + props.put("spark.executor.extraClassPath", testClasspath) // SPARK-4267: make sure java options are propagated correctly. props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - yarnCluster.getConfig.asScala.foreach { e => + yarnCluster.getConfig().asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } - sys.props.foreach { case (k, v) => if (k.startsWith("spark.")) { props.setProperty(k, v) } } - extraConf.foreach { case (k, v) => props.setProperty(k, v) } val propsFile = File.createTempFile("spark", ".properties", tempDir) val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) props.store(writer, "Spark properties.") writer.close() - - val extraJarArgs = if (extraJars.nonEmpty) Seq("--jars", extraJars.mkString(",")) else Nil - val mainArgs = - if (klass.endsWith(".py")) { - Seq(klass) - } else { - Seq("--class", klass, fakeSparkJar.getAbsolutePath()) - } - val argv = - Seq( - new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), - "--master", master, - "--num-executors", "1", - "--properties-file", propsFile.getAbsolutePath()) ++ - extraJarArgs ++ - sparkArgs ++ - mainArgs ++ - appArgs - - Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv) - } - - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. - */ - protected def checkResult(result: File): Unit = { - checkResult(result, "success") - } - - protected def checkResult(result: File, expected: String): Unit = { - val resultString = Files.toString(result, UTF_8) - resultString should be (expected) - } - - protected def mainClassName(klass: Class[_]): String = { - klass.getName().stripSuffix("$") + propsFile.getAbsolutePath() } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index f1601cd161..d1cd0c89b5 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -19,16 +19,20 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URL +import java.util.{HashMap => JHashMap, Properties} import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.launcher._ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -82,10 +86,8 @@ class YarnClusterSuite extends BaseYarnClusterSuite { test("run Spark in yarn-cluster mode unsuccessfully") { // Don't provide arguments so the driver will fail. - val exception = intercept[SparkException] { - runSpark(false, mainClassName(YarnClusterDriver.getClass)) - fail("Spark application should have failed.") - } + val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass)) + finalState should be (SparkAppHandle.State.FAILED) } test("run Python application in yarn-client mode") { @@ -104,11 +106,42 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testUseClassPathFirst(false) } + test("monitor app using launcher library") { + val env = new JHashMap[String, String]() + env.put("YARN_CONF_DIR", hadoopConfDir.getAbsolutePath()) + + val propsFile = createConfFile() + val handle = new SparkLauncher(env) + .setSparkHome(sys.props("spark.test.home")) + .setConf("spark.ui.enabled", "false") + .setPropertiesFile(propsFile) + .setMaster("yarn-client") + .setAppResource("spark-internal") + .setMainClass(mainClassName(YarnLauncherTestApp.getClass)) + .startApplication() + + try { + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.RUNNING) + } + + handle.getAppId() should not be (null) + handle.getAppId() should startWith ("application_") + handle.stop() + + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.KILLED) + } + } finally { + handle.kill() + } + } + private def testBasicYarnApp(clientMode: Boolean): Unit = { val result = File.createTempFile("result", null, tempDir) - runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), + val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), appArgs = Seq(result.getAbsolutePath())) - checkResult(result) + checkResult(finalState, result) } private def testPySpark(clientMode: Boolean): Unit = { @@ -143,11 +176,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") val result = File.createTempFile("result", null, tempDir) - runSpark(clientMode, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFiles), + val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files" -> pyFiles), appArgs = Seq(result.getAbsolutePath()), extraEnv = extraEnv) - checkResult(result) + checkResult(finalState, result) } private def testUseClassPathFirst(clientMode: Boolean): Unit = { @@ -156,15 +189,15 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir) val driverResult = File.createTempFile("driver", null, tempDir) val executorResult = File.createTempFile("executor", null, tempDir) - runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), extraClassPath = Seq(originalJar.getPath()), extraJars = Seq("local:" + userJar.getPath()), extraConf = Map( "spark.driver.userClassPathFirst" -> "true", "spark.executor.userClassPathFirst" -> "true")) - checkResult(driverResult, "OVERRIDDEN") - checkResult(executorResult, "OVERRIDDEN") + checkResult(finalState, driverResult, "OVERRIDDEN") + checkResult(finalState, executorResult, "OVERRIDDEN") } } @@ -211,8 +244,8 @@ private object YarnClusterDriver extends Logging with Matchers { data should be (Set(1, 2, 3, 4)) result = "success" } finally { - sc.stop() Files.write(result, status, UTF_8) + sc.stop() } // verify log urls are present @@ -297,3 +330,18 @@ private object YarnClasspathTest extends Logging { } } + +private object YarnLauncherTestApp { + + def main(args: Array[String]): Unit = { + // Do not stop the application; the test will stop it using the launcher lib. Just run a task + // that will prevent the process from exiting. + val sc = new SparkContext(new SparkConf()) + sc.parallelize(Seq(1)).foreach { i => + this.synchronized { + wait() + } + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index a85e5772a0..c17e8695c2 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -53,7 +53,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { logInfo("Shuffle service port = " + shuffleServicePort) val result = File.createTempFile("result", null, tempDir) - runSpark( + val finalState = runSpark( false, mainClassName(YarnExternalShuffleDriver.getClass), appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), @@ -62,7 +62,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { "spark.shuffle.service.port" -> shuffleServicePort.toString ) ) - checkResult(result) + checkResult(finalState, result) assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) } } |