-rw-r--r--  core/pom.xml                                                                      |   8
-rw-r--r--  core/src/main/resources/org/apache/spark/ui/static/webui.css                      |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                           |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala                 |  20
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala                   |  15
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala               |  12
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala                     | 142
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala            |  81
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/Master.scala                   |  46
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala           |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala            |  18
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala       | 307
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala       | 449
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala |  36
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala  | 146
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala  |  78
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala |  85
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala               |   3
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala                |  57
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala  | 265
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala    | 324
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala |   6
-rw-r--r--  pom.xml                                                                            |  11
23 files changed, 2027 insertions(+), 94 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 4daaf88147..66180035e6 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -244,6 +244,14 @@
<artifactId>metrics-graphite</artifactId>
</dependency>
<dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.module</groupId>
+ <artifactId>jackson-module-scala_2.10</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.apache.derby</groupId>
<artifactId>derby</artifactId>
<scope>test</scope>
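
Note: the two Jackson artifacts added above back the JSON (de)serialization used by the new REST submission protocol. A minimal sketch of the pattern they enable, using a hypothetical message class for illustration (the real protocol messages are defined in SubmitRestProtocolMessage.scala later in this patch):

    import com.fasterxml.jackson.databind.ObjectMapper
    import com.fasterxml.jackson.module.scala.DefaultScalaModule

    // Hypothetical message type, for illustration only.
    case class ExampleMessage(action: String, clientSparkVersion: String)

    // The Scala module lets Jackson handle case classes and Scala collections.
    val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
    val json = mapper.writeValueAsString(ExampleMessage("CreateSubmissionRequest", "1.3.0"))
    val parsed = mapper.readValue(json, classOf[ExampleMessage])
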
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index f23ba9dba1..68b33b5f0d 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -103,6 +103,12 @@ span.expand-details {
float: right;
}
+span.rest-uri {
+ font-size: 10pt;
+ font-style: italic;
+ color: gray;
+}
+
pre {
font-size: 0.8em;
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 5623587c36..71bdbc9b38 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2110,7 +2110,7 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc)
val localCluster = new LocalSparkCluster(
- numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
+ numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf)
val masterUrls = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index e5873ce724..415bd50591 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -29,8 +29,7 @@ import org.apache.spark.util.{IntParam, MemoryParam}
* Command-line parser for the driver client.
*/
private[spark] class ClientArguments(args: Array[String]) {
- val defaultCores = 1
- val defaultMemory = 512
+ import ClientArguments._
var cmd: String = "" // 'launch' or 'kill'
var logLevel = Level.WARN
@@ -39,9 +38,9 @@ private[spark] class ClientArguments(args: Array[String]) {
var master: String = ""
var jarUrl: String = ""
var mainClass: String = ""
- var supervise: Boolean = false
- var memory: Int = defaultMemory
- var cores: Int = defaultCores
+ var supervise: Boolean = DEFAULT_SUPERVISE
+ var memory: Int = DEFAULT_MEMORY
+ var cores: Int = DEFAULT_CORES
private var _driverOptions = ListBuffer[String]()
def driverOptions = _driverOptions.toSeq
@@ -50,7 +49,7 @@ private[spark] class ClientArguments(args: Array[String]) {
parse(args.toList)
- def parse(args: List[String]): Unit = args match {
+ private def parse(args: List[String]): Unit = args match {
case ("--cores" | "-c") :: IntParam(value) :: tail =>
cores = value
parse(tail)
@@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) {
|Usage: DriverClient kill <active-master> <driver-id>
|
|Options:
- | -c CORES, --cores CORES Number of cores to request (default: $defaultCores)
- | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory)
+ | -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES)
+ | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY)
| -s, --supervise Whether to restart the driver on failure
+ | (default: $DEFAULT_SUPERVISE)
| -v, --verbose Print more debugging output
""".stripMargin
System.err.println(usage)
@@ -117,6 +117,10 @@ private[spark] class ClientArguments(args: Array[String]) {
}
object ClientArguments {
+ private[spark] val DEFAULT_CORES = 1
+ private[spark] val DEFAULT_MEMORY = 512 // MB
+ private[spark] val DEFAULT_SUPERVISE = false
+
def isValidJarUrl(s: String): Boolean = {
try {
val uri = new URI(s)
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 243d8edb72..7f600d8960 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -148,15 +148,22 @@ private[deploy] object DeployMessages {
// Master to MasterWebUI
- case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
- activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo],
- activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo],
- status: MasterState) {
+ case class MasterStateResponse(
+ host: String,
+ port: Int,
+ restPort: Option[Int],
+ workers: Array[WorkerInfo],
+ activeApps: Array[ApplicationInfo],
+ completedApps: Array[ApplicationInfo],
+ activeDrivers: Array[DriverInfo],
+ completedDrivers: Array[DriverInfo],
+ status: MasterState) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
def uri = "spark://" + host + ":" + port
+ def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p }
}
// WorkerWebUI to Worker
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 9a7a113c95..0401b15446 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -33,7 +33,11 @@ import org.apache.spark.util.Utils
* fault recovery without spinning up a lot of processes.
*/
private[spark]
-class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int)
+class LocalSparkCluster(
+ numWorkers: Int,
+ coresPerWorker: Int,
+ memoryPerWorker: Int,
+ conf: SparkConf)
extends Logging {
private val localHostname = Utils.localHostName()
@@ -43,9 +47,11 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
+ // Disable REST server on Master in this mode unless otherwise specified
+ val _conf = conf.clone().setIfMissing("spark.master.rest.enabled", "false")
+
/* Start the Master */
- val conf = new SparkConf(false)
- val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf)
+ val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localHostname + ":" + masterPort
val masters = Array(masterUrl)
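
Note: local-cluster mode now clones the caller's conf and turns the Master's REST server off unless explicitly requested. A minimal sketch of opting back in, with illustrative worker sizes:

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.LocalSparkCluster

    // Re-enable the REST gateway that LocalSparkCluster disables by default.
    val conf = new SparkConf().set("spark.master.rest.enabled", "true")
    val cluster = new LocalSparkCluster(
      numWorkers = 2, coresPerWorker = 1, memoryPerWorker = 512, conf = conf)
    val masterUrls = cluster.start()
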
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 9d25e647f1..6d213926f3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -18,25 +18,35 @@
package org.apache.spark.deploy
import java.io.{File, PrintStream}
-import java.lang.reflect.{Modifier, InvocationTargetException}
+import java.lang.reflect.{InvocationTargetException, Modifier}
import java.net.URL
+
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import org.apache.hadoop.fs.Path
import org.apache.ivy.Ivy
import org.apache.ivy.core.LogOptions
-import org.apache.ivy.core.module.descriptor.{DefaultExcludeRule, DefaultDependencyDescriptor, DefaultModuleDescriptor}
-import org.apache.ivy.core.module.id.{ModuleId, ArtifactId, ModuleRevisionId}
+import org.apache.ivy.core.module.descriptor._
+import org.apache.ivy.core.module.id.{ArtifactId, ModuleId, ModuleRevisionId}
import org.apache.ivy.core.report.ResolveReport
-import org.apache.ivy.core.resolve.{IvyNode, ResolveOptions}
+import org.apache.ivy.core.resolve.ResolveOptions
import org.apache.ivy.core.retrieve.RetrieveOptions
import org.apache.ivy.core.settings.IvySettings
import org.apache.ivy.plugins.matcher.GlobPatternMatcher
import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver}
-import org.apache.spark.executor.ExecutorURLClassLoader
+
+import org.apache.spark.deploy.rest._
+import org.apache.spark.executor._
import org.apache.spark.util.Utils
-import org.apache.spark.executor.ChildExecutorURLClassLoader
-import org.apache.spark.executor.MutableURLClassLoader
+
+/**
+ * Whether to submit, kill, or request the status of an application.
+ * The latter two operations are currently supported only for standalone cluster mode.
+ */
+private[spark] object SparkSubmitAction extends Enumeration {
+ type SparkSubmitAction = Value
+ val SUBMIT, KILL, REQUEST_STATUS = Value
+}
/**
* Main gateway of launching a Spark application.
@@ -83,21 +93,74 @@ object SparkSubmit {
if (appArgs.verbose) {
printStream.println(appArgs)
}
- val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
- launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
+ appArgs.action match {
+ case SparkSubmitAction.SUBMIT => submit(appArgs)
+ case SparkSubmitAction.KILL => kill(appArgs)
+ case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs)
+ }
+ }
+
+ /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */
+ private def kill(args: SparkSubmitArguments): Unit = {
+ new StandaloneRestClient()
+ .killSubmission(args.master, args.submissionToKill)
}
/**
- * @return a tuple containing
- * (1) the arguments for the child process,
- * (2) a list of classpath entries for the child,
- * (3) a list of system properties and env vars, and
- * (4) the main class for the child
+ * Request the status of an existing submission using the REST protocol.
+ * Standalone cluster mode only.
*/
- private[spark] def createLaunchEnv(args: SparkSubmitArguments)
- : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = {
+ private def requestStatus(args: SparkSubmitArguments): Unit = {
+ new StandaloneRestClient()
+ .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor)
+ }
- // Values to return
+ /**
+ * Submit the application using the provided parameters.
+ *
+ * This runs in two steps. First, we prepare the launch environment by setting up
+ * the appropriate classpath, system properties, and application arguments for
+ * running the child main class based on the cluster manager and the deploy mode.
+ * Second, we use this launch environment to invoke the main method of the child
+ * main class.
+ */
+ private[spark] def submit(args: SparkSubmitArguments): Unit = {
+ val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
+ // In standalone cluster mode, there are two submission gateways:
+ // (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper
+ // (2) The new REST-based gateway introduced in Spark 1.3
+ // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over
+ // to use the legacy gateway if the master endpoint turns out to be not a REST server.
+ if (args.isStandaloneCluster && args.useRest) {
+ try {
+ printStream.println("Running Spark using the REST application submission protocol.")
+ runMain(childArgs, childClasspath, sysProps, childMainClass)
+ } catch {
+ // Fail over to use the legacy submission gateway
+ case e: SubmitRestConnectionException =>
+ printWarning(s"Master endpoint ${args.master} was not a REST server. " +
+ "Falling back to legacy submission gateway instead.")
+ args.useRest = false
+ submit(args)
+ }
+ // In all other modes, just run the main class as prepared
+ } else {
+ runMain(childArgs, childClasspath, sysProps, childMainClass)
+ }
+ }
+
+ /**
+ * Prepare the environment for submitting an application.
+ * This returns a 4-tuple:
+ * (1) the arguments for the child process,
+ * (2) a list of classpath entries for the child,
+ * (3) a map of system properties, and
+ * (4) the main class for the child
+ * Exposed for testing.
+ */
+ private[spark] def prepareSubmitEnvironment(args: SparkSubmitArguments)
+ : (Seq[String], Seq[String], Map[String, String], String) = {
+ // Return values
val childArgs = new ArrayBuffer[String]()
val childClasspath = new ArrayBuffer[String]()
val sysProps = new HashMap[String, String]()
@@ -235,10 +298,13 @@ object SparkSubmit {
sysProp = "spark.driver.extraLibraryPath"),
// Standalone cluster only
+ // Do not set CL arguments here because there are multiple possibilities for the main class
OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"),
OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"),
- OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"),
- OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"),
+ OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"),
+ OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"),
+ OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER,
+ sysProp = "spark.driver.supervise"),
// Yarn client only
OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"),
@@ -279,7 +345,6 @@ object SparkSubmit {
if (args.childArgs != null) { childArgs ++= args.childArgs }
}
-
// Map all arguments to command-line options or system properties for our chosen mode
for (opt <- options) {
if (opt.value != null &&
@@ -301,14 +366,21 @@ object SparkSubmit {
sysProps.put("spark.jars", jars.mkString(","))
}
- // In standalone-cluster mode, use Client as a wrapper around the user class
- if (clusterManager == STANDALONE && deployMode == CLUSTER) {
- childMainClass = "org.apache.spark.deploy.Client"
- if (args.supervise) {
- childArgs += "--supervise"
+ // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+).
+ // All Spark parameters are expected to be passed to the client through system properties.
+ if (args.isStandaloneCluster) {
+ if (args.useRest) {
+ childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient"
+ childArgs += (args.primaryResource, args.mainClass)
+ } else {
+ // In legacy standalone cluster mode, use Client as a wrapper around the user class
+ childMainClass = "org.apache.spark.deploy.Client"
+ if (args.supervise) { childArgs += "--supervise" }
+ Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) }
+ Option(args.driverCores).foreach { c => childArgs += ("--cores", c) }
+ childArgs += "launch"
+ childArgs += (args.master, args.primaryResource, args.mainClass)
}
- childArgs += "launch"
- childArgs += (args.master, args.primaryResource, args.mainClass)
if (args.childArgs != null) {
childArgs ++= args.childArgs
}
@@ -345,7 +417,7 @@ object SparkSubmit {
// Ignore invalid spark.driver.host in cluster modes.
if (deployMode == CLUSTER) {
- sysProps -= ("spark.driver.host")
+ sysProps -= "spark.driver.host"
}
// Resolve paths in certain spark properties
@@ -374,9 +446,15 @@ object SparkSubmit {
(childArgs, childClasspath, sysProps, childMainClass)
}
- private def launch(
- childArgs: ArrayBuffer[String],
- childClasspath: ArrayBuffer[String],
+ /**
+ * Run the main method of the child class using the provided launch environment.
+ *
+ * Note that this main class will not be the one provided by the user if we're
+ * running cluster deploy mode or python applications.
+ */
+ private def runMain(
+ childArgs: Seq[String],
+ childClasspath: Seq[String],
sysProps: Map[String, String],
childMainClass: String,
verbose: Boolean = false) {
@@ -697,7 +775,7 @@ private[spark] object SparkSubmitUtils {
* Provides an indirection layer for passing arguments as system properties or flags to
* the user's driver program or to downstream launcher tools.
*/
-private[spark] case class OptionAssigner(
+private case class OptionAssigner(
value: String,
clusterManager: Int,
deployMode: Int,
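
Note: since prepareSubmitEnvironment is exposed for testing, a test can assert on the launch environment directly. A sketch under illustrative arguments (not taken verbatim from SparkSubmitSuite): with the REST gateway left enabled, the child main class should be the REST client and the driver settings should surface as system properties rather than as --memory/--cores child arguments.

    val clArgs = Seq(
      "--master", "spark://h:7077",
      "--deploy-mode", "cluster",
      "--supervise",
      "--driver-memory", "4g",
      "--driver-cores", "5",
      "--class", "org.example.Main",   // illustrative user class and jar
      "app.jar", "arg1", "arg2")
    val appArgs = new SparkSubmitArguments(clArgs)
    val (childArgs, classpath, sysProps, mainClass) =
      SparkSubmit.prepareSubmitEnvironment(appArgs)
    // Expected in this mode:
    //   mainClass == "org.apache.spark.deploy.rest.StandaloneRestClient"
    //   sysProps contains spark.driver.memory=4g, spark.driver.cores=5,
    //   and spark.driver.supervise=true
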
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 5cadc534f4..bd0ae26fd8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -22,6 +22,7 @@ import java.util.jar.JarFile
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import org.apache.spark.deploy.SparkSubmitAction._
import org.apache.spark.util.Utils
/**
@@ -39,8 +40,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var driverExtraClassPath: String = null
var driverExtraLibraryPath: String = null
var driverExtraJavaOptions: String = null
- var driverCores: String = null
- var supervise: Boolean = false
var queue: String = null
var numExecutors: String = null
var files: String = null
@@ -56,8 +55,16 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
+ var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
+ // Standalone cluster mode only
+ var supervise: Boolean = false
+ var driverCores: String = null
+ var submissionToKill: String = null
+ var submissionToRequestStatusFor: String = null
+ var useRest: Boolean = true // used internally
+
/** Default properties present in the currently defined defaults file. */
lazy val defaultSparkProperties: HashMap[String, String] = {
val defaultProperties = new HashMap[String, String]()
@@ -82,7 +89,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
// Use `sparkProperties` map along with env vars to fill in any missing parameters
loadEnvironmentArguments()
- checkRequiredArguments()
+ validateArguments()
/**
* Merge values from the default properties file with those specified through --conf.
@@ -107,6 +114,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
.orElse(sparkProperties.get("spark.master"))
.orElse(env.get("MASTER"))
.orNull
+ driverExtraClassPath = Option(driverExtraClassPath)
+ .orElse(sparkProperties.get("spark.driver.extraClassPath"))
+ .orNull
+ driverExtraJavaOptions = Option(driverExtraJavaOptions)
+ .orElse(sparkProperties.get("spark.driver.extraJavaOptions"))
+ .orNull
+ driverExtraLibraryPath = Option(driverExtraLibraryPath)
+ .orElse(sparkProperties.get("spark.driver.extraLibraryPath"))
+ .orNull
driverMemory = Option(driverMemory)
.orElse(sparkProperties.get("spark.driver.memory"))
.orElse(env.get("SPARK_DRIVER_MEMORY"))
@@ -166,10 +182,21 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
if (name == null && primaryResource != null) {
name = Utils.stripDirectory(primaryResource)
}
+
+ // Action should be SUBMIT unless otherwise specified
+ action = Option(action).getOrElse(SUBMIT)
}
/** Ensure that required fields exists. Call this only once all defaults are loaded. */
- private def checkRequiredArguments(): Unit = {
+ private def validateArguments(): Unit = {
+ action match {
+ case SUBMIT => validateSubmitArguments()
+ case KILL => validateKillArguments()
+ case REQUEST_STATUS => validateStatusRequestArguments()
+ }
+ }
+
+ private def validateSubmitArguments(): Unit = {
if (args.length == 0) {
printUsageAndExit(-1)
}
@@ -192,6 +219,29 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
+ private def validateKillArguments(): Unit = {
+ if (!master.startsWith("spark://")) {
+ SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!")
+ }
+ if (submissionToKill == null) {
+ SparkSubmit.printErrorAndExit("Please specify a submission to kill.")
+ }
+ }
+
+ private def validateStatusRequestArguments(): Unit = {
+ if (!master.startsWith("spark://")) {
+ SparkSubmit.printErrorAndExit(
+ "Requesting submission statuses is only supported in standalone mode!")
+ }
+ if (submissionToRequestStatusFor == null) {
+ SparkSubmit.printErrorAndExit("Please specify a submission to request status for.")
+ }
+ }
+
+ def isStandaloneCluster: Boolean = {
+ master.startsWith("spark://") && deployMode == "cluster"
+ }
+
override def toString = {
s"""Parsed arguments:
| master $master
@@ -300,6 +350,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
propertiesFile = value
parse(tail)
+ case ("--kill") :: value :: tail =>
+ submissionToKill = value
+ if (action != null) {
+ SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.")
+ }
+ action = KILL
+ parse(tail)
+
+ case ("--status") :: value :: tail =>
+ submissionToRequestStatusFor = value
+ if (action != null) {
+ SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.")
+ }
+ action = REQUEST_STATUS
+ parse(tail)
+
case ("--supervise") :: tail =>
supervise = true
parse(tail)
@@ -372,7 +438,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
outStream.println("Unknown/unsupported param " + unknownParam)
}
outStream.println(
- """Usage: spark-submit [options] <app jar | python file> [app options]
+ """Usage: spark-submit [options] <app jar | python file> [app arguments]
+ |Usage: spark-submit --kill [submission ID] --master [spark://...]
+ |Usage: spark-submit --status [submission ID] --master [spark://...]
+ |
|Options:
| --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
| --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or
@@ -413,6 +482,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| Spark standalone with cluster deploy mode only:
| --driver-cores NUM Cores for driver (Default: 1).
| --supervise If given, restarts the driver on failure.
+ | --kill SUBMISSION_ID If given, kills the driver specified.
+ | --status SUBMISSION_ID If given, requests the status of the driver specified.
|
| Spark standalone and Mesos only:
| --total-executor-cores NUM Total cores for all executors.
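
Note: the new --kill and --status flags map onto the KILL and REQUEST_STATUS actions parsed above. A small sketch, with an illustrative submission ID and master URL:

    import org.apache.spark.deploy.{SparkSubmitAction, SparkSubmitArguments}

    val killArgs = new SparkSubmitArguments(Seq(
      "--master", "spark://host:6066",
      "--kill", "driver-20150201000000-0000"))
    // killArgs.action == SparkSubmitAction.KILL
    // killArgs.submissionToKill == "driver-20150201000000-0000"

    val statusArgs = new SparkSubmitArguments(Seq(
      "--master", "spark://host:6066",
      "--status", "driver-20150201000000-0000"))
    // statusArgs.action == SparkSubmitAction.REQUEST_STATUS
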
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 5eeb9fe526..b8b1a25abf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -43,6 +43,7 @@ import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
+import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
import org.apache.spark.ui.SparkUI
@@ -52,12 +53,12 @@ private[spark] class Master(
host: String,
port: Int,
webUiPort: Int,
- val securityMgr: SecurityManager)
+ val securityMgr: SecurityManager,
+ val conf: SparkConf)
extends Actor with ActorLogReceive with Logging with LeaderElectable {
import context.dispatcher // to use Akka's scheduler.schedule()
- val conf = new SparkConf
val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
@@ -121,6 +122,17 @@ private[spark] class Master(
throw new SparkException("spark.deploy.defaultCores must be positive")
}
+ // Alternative application submission gateway that is stable across Spark versions
+ private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true)
+ private val restServer =
+ if (restServerEnabled) {
+ val port = conf.getInt("spark.master.rest.port", 6066)
+ Some(new StandaloneRestServer(host, port, self, masterUrl, conf))
+ } else {
+ None
+ }
+ private val restServerBoundPort = restServer.map(_.start())
+
override def preStart() {
logInfo("Starting Spark master at " + masterUrl)
logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
@@ -174,6 +186,7 @@ private[spark] class Master(
recoveryCompletionTask.cancel()
}
webUi.stop()
+ restServer.foreach(_.stop())
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
persistenceEngine.close()
@@ -421,7 +434,9 @@ private[spark] class Master(
}
case RequestMasterState => {
- sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
+ sender ! MasterStateResponse(
+ host, port, restServerBoundPort,
+ workers.toArray, apps.toArray, completedApps.toArray,
drivers.toArray, completedDrivers.toArray, state)
}
@@ -429,8 +444,8 @@ private[spark] class Master(
timeOutDeadWorkers()
}
- case RequestWebUIPort => {
- sender ! WebUIPortResponse(webUi.boundPort)
+ case BoundPortsRequest => {
+ sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort)
}
}
@@ -851,7 +866,7 @@ private[spark] object Master extends Logging {
SignalLogger.register(log)
val conf = new SparkConf
val args = new MasterArguments(argStrings, conf)
- val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf)
+ val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf)
actorSystem.awaitTermination()
}
@@ -875,19 +890,26 @@ private[spark] object Master extends Logging {
Address(protocol, systemName, host, port)
}
+ /**
+ * Start the Master and return a four tuple of:
+ * (1) The Master actor system
+ * (2) The bound port
+ * (3) The web UI bound port
+ * (4) The REST server bound port, if any
+ */
def startSystemAndActor(
host: String,
port: Int,
webUiPort: Int,
- conf: SparkConf): (ActorSystem, Int, Int) = {
+ conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = {
val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
securityManager = securityMgr)
- val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort,
- securityMgr), actorName)
+ val actor = actorSystem.actorOf(
+ Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName)
val timeout = AkkaUtils.askTimeout(conf)
- val respFuture = actor.ask(RequestWebUIPort)(timeout)
- val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
- (actorSystem, boundPort, resp.webUIBoundPort)
+ val portsRequest = actor.ask(BoundPortsRequest)(timeout)
+ val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse]
+ (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort)
}
}
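
Note: the Master now reads two configuration keys for the REST gateway, spark.master.rest.enabled (default true) and spark.master.rest.port (default 6066). A sketch of starting a Master with those defaults spelled out and reading back the bound REST port:

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.master.Master

    val conf = new SparkConf()
      .set("spark.master.rest.enabled", "true")   // restates the default
      .set("spark.master.rest.port", "6066")      // restates the default
    val (actorSystem, boundPort, webUiPort, restPort) =
      Master.startSystemAndActor("localhost", 7077, 8080, conf)
    // restPort is Some(<bound REST port>) when the server is enabled, None otherwise.
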
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
index db72d8ae9b..15c6296888 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -36,7 +36,7 @@ private[master] object MasterMessages {
case object CompleteRecovery
- case object RequestWebUIPort
+ case object BoundPortsRequest
- case class WebUIPortResponse(webUIBoundPort: Int)
+ case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int])
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 7ca3b08a28..b47a081053 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -46,19 +46,19 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
val state = Await.result(stateFuture, timeout)
- val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory")
+ val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory")
val workers = state.workers.sortBy(_.id)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
- val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User",
- "State", "Duration")
+ val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
+ "User", "State", "Duration")
val activeApps = state.activeApps.sortBy(_.startTime).reverse
val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
val completedApps = state.completedApps.sortBy(_.endTime).reverse
val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
- val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory",
- "Main Class")
+ val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores",
+ "Memory", "Main Class")
val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse
val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers)
val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse
@@ -73,6 +73,14 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
<div class="span12">
<ul class="unstyled">
<li><strong>URL:</strong> {state.uri}</li>
+ {
+ state.restUri.map { uri =>
+ <li>
+ <strong>REST URL:</strong> {uri}
+ <span class="rest-uri"> (cluster mode)</span>
+ </li>
+ }.getOrElse { Seq.empty }
+ }
<li><strong>Workers:</strong> {state.workers.size}</li>
<li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total,
{state.workers.map(_.coresUsed).sum} Used</li>
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
new file mode 100644
index 0000000000..115aa5278b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.io.{DataOutputStream, FileNotFoundException}
+import java.net.{HttpURLConnection, SocketException, URL}
+
+import scala.io.Source
+
+import com.fasterxml.jackson.databind.JsonMappingException
+import com.google.common.base.Charsets
+
+import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
+
+/**
+ * A client that submits applications to the standalone Master using a REST protocol.
+ * This client is intended to communicate with the [[StandaloneRestServer]] and is
+ * currently used for cluster mode only.
+ *
+ * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action],
+ * where [action] can be one of create, kill, or status. Each type of request is represented in
+ * an HTTP message sent to the following prefixes:
+ * (1) submit - POST to /submissions/create
+ * (2) kill - POST /submissions/kill/[submissionId]
+ * (3) status - GET /submissions/status/[submissionId]
+ *
+ * In the case of (1), parameters are posted in the HTTP body in the form of JSON fields.
+ * Otherwise, the URL fully specifies the intended action of the client.
+ *
+ * Since the protocol is expected to be stable across Spark versions, existing fields cannot be
+ * added or removed, though new optional fields can be added. In the rare event that forward or
+ * backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2).
+ *
+ * The client and the server must communicate using the same version of the protocol. If there
+ * is a mismatch, the server will respond with the highest protocol version it supports. A future
+ * implementation of this client can use that information to retry using the version specified
+ * by the server.
+ */
+private[spark] class StandaloneRestClient extends Logging {
+ import StandaloneRestClient._
+
+ /**
+ * Submit an application specified by the parameters in the provided request.
+ *
+ * If the submission was successful, poll the status of the submission and report
+ * it to the user. Otherwise, report the error message provided by the server.
+ */
+ def createSubmission(
+ master: String,
+ request: CreateSubmissionRequest): SubmitRestProtocolResponse = {
+ logInfo(s"Submitting a request to launch an application in $master.")
+ validateMaster(master)
+ val url = getSubmitUrl(master)
+ val response = postJson(url, request.toJson)
+ response match {
+ case s: CreateSubmissionResponse =>
+ reportSubmissionStatus(master, s)
+ handleRestResponse(s)
+ case unexpected =>
+ handleUnexpectedRestResponse(unexpected)
+ }
+ response
+ }
+
+ /** Request that the server kill the specified submission. */
+ def killSubmission(master: String, submissionId: String): SubmitRestProtocolResponse = {
+ logInfo(s"Submitting a request to kill submission $submissionId in $master.")
+ validateMaster(master)
+ val response = post(getKillUrl(master, submissionId))
+ response match {
+ case k: KillSubmissionResponse => handleRestResponse(k)
+ case unexpected => handleUnexpectedRestResponse(unexpected)
+ }
+ response
+ }
+
+ /** Request the status of a submission from the server. */
+ def requestSubmissionStatus(
+ master: String,
+ submissionId: String,
+ quiet: Boolean = false): SubmitRestProtocolResponse = {
+ logInfo(s"Submitting a request for the status of submission $submissionId in $master.")
+ validateMaster(master)
+ val response = get(getStatusUrl(master, submissionId))
+ response match {
+ case s: SubmissionStatusResponse => if (!quiet) { handleRestResponse(s) }
+ case unexpected => handleUnexpectedRestResponse(unexpected)
+ }
+ response
+ }
+
+ /** Construct a message that captures the specified parameters for submitting an application. */
+ def constructSubmitRequest(
+ appResource: String,
+ mainClass: String,
+ appArgs: Array[String],
+ sparkProperties: Map[String, String],
+ environmentVariables: Map[String, String]): CreateSubmissionRequest = {
+ val message = new CreateSubmissionRequest
+ message.clientSparkVersion = sparkVersion
+ message.appResource = appResource
+ message.mainClass = mainClass
+ message.appArgs = appArgs
+ message.sparkProperties = sparkProperties
+ message.environmentVariables = environmentVariables
+ message.validate()
+ message
+ }
+
+ /** Send a GET request to the specified URL. */
+ private def get(url: URL): SubmitRestProtocolResponse = {
+ logDebug(s"Sending GET request to server at $url.")
+ val conn = url.openConnection().asInstanceOf[HttpURLConnection]
+ conn.setRequestMethod("GET")
+ readResponse(conn)
+ }
+
+ /** Send a POST request to the specified URL. */
+ private def post(url: URL): SubmitRestProtocolResponse = {
+ logDebug(s"Sending POST request to server at $url.")
+ val conn = url.openConnection().asInstanceOf[HttpURLConnection]
+ conn.setRequestMethod("POST")
+ readResponse(conn)
+ }
+
+ /** Send a POST request with the given JSON as the body to the specified URL. */
+ private def postJson(url: URL, json: String): SubmitRestProtocolResponse = {
+ logDebug(s"Sending POST request to server at $url:\n$json")
+ val conn = url.openConnection().asInstanceOf[HttpURLConnection]
+ conn.setRequestMethod("POST")
+ conn.setRequestProperty("Content-Type", "application/json")
+ conn.setRequestProperty("charset", "utf-8")
+ conn.setDoOutput(true)
+ val out = new DataOutputStream(conn.getOutputStream)
+ out.write(json.getBytes(Charsets.UTF_8))
+ out.close()
+ readResponse(conn)
+ }
+
+ /**
+ * Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]].
+ * If the response represents an error, report the embedded message to the user.
+ */
+ private def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
+ try {
+ val responseJson = Source.fromInputStream(connection.getInputStream).mkString
+ logDebug(s"Response from the server:\n$responseJson")
+ val response = SubmitRestProtocolMessage.fromJson(responseJson)
+ response.validate()
+ response match {
+ // If the response is an error, log the message
+ case error: ErrorResponse =>
+ logError(s"Server responded with error:\n${error.message}")
+ error
+ // Otherwise, simply return the response
+ case response: SubmitRestProtocolResponse => response
+ case unexpected =>
+ throw new SubmitRestProtocolException(
+ s"Message received from server was not a response:\n${unexpected.toJson}")
+ }
+ } catch {
+ case unreachable @ (_: FileNotFoundException | _: SocketException) =>
+ throw new SubmitRestConnectionException(
+ s"Unable to connect to server ${connection.getURL}", unreachable)
+ case malformed @ (_: SubmitRestProtocolException | _: JsonMappingException) =>
+ throw new SubmitRestProtocolException(
+ "Malformed response received from server", malformed)
+ }
+ }
+
+ /** Return the REST URL for creating a new submission. */
+ private def getSubmitUrl(master: String): URL = {
+ val baseUrl = getBaseUrl(master)
+ new URL(s"$baseUrl/create")
+ }
+
+ /** Return the REST URL for killing an existing submission. */
+ private def getKillUrl(master: String, submissionId: String): URL = {
+ val baseUrl = getBaseUrl(master)
+ new URL(s"$baseUrl/kill/$submissionId")
+ }
+
+ /** Return the REST URL for requesting the status of an existing submission. */
+ private def getStatusUrl(master: String, submissionId: String): URL = {
+ val baseUrl = getBaseUrl(master)
+ new URL(s"$baseUrl/status/$submissionId")
+ }
+
+ /** Return the base URL for communicating with the server, including the protocol version. */
+ private def getBaseUrl(master: String): String = {
+ val masterUrl = master.stripPrefix("spark://").stripSuffix("/")
+ s"http://$masterUrl/$PROTOCOL_VERSION/submissions"
+ }
+
+ /** Throw an exception if this is not standalone mode. */
+ private def validateMaster(master: String): Unit = {
+ if (!master.startsWith("spark://")) {
+ throw new IllegalArgumentException("This REST client is only supported in standalone mode.")
+ }
+ }
+
+ /** Report the status of a newly created submission. */
+ private def reportSubmissionStatus(
+ master: String,
+ submitResponse: CreateSubmissionResponse): Unit = {
+ if (submitResponse.success) {
+ val submissionId = submitResponse.submissionId
+ if (submissionId != null) {
+ logInfo(s"Submission successfully created as $submissionId. Polling submission state...")
+ pollSubmissionStatus(master, submissionId)
+ } else {
+ // should never happen
+ logError("Application successfully submitted, but submission ID was not provided!")
+ }
+ } else {
+ val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("")
+ logError("Application submission failed" + failMessage)
+ }
+ }
+
+ /**
+ * Poll the status of the specified submission and log it.
+ * This retries up to a fixed number of times before giving up.
+ */
+ private def pollSubmissionStatus(master: String, submissionId: String): Unit = {
+ (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ =>
+ val response = requestSubmissionStatus(master, submissionId, quiet = true)
+ val statusResponse = response match {
+ case s: SubmissionStatusResponse => s
+ case _ => return // unexpected type, let upstream caller handle it
+ }
+ if (statusResponse.success) {
+ val driverState = Option(statusResponse.driverState)
+ val workerId = Option(statusResponse.workerId)
+ val workerHostPort = Option(statusResponse.workerHostPort)
+ val exception = Option(statusResponse.message)
+ // Log driver state, if present
+ driverState match {
+ case Some(state) => logInfo(s"State of driver $submissionId is now $state.")
+ case _ => logError(s"State of driver $submissionId was not found!")
+ }
+ // Log worker node, if present
+ (workerId, workerHostPort) match {
+ case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.")
+ case _ =>
+ }
+ // Log exception stack trace, if present
+ exception.foreach { e => logError(e) }
+ return
+ }
+ Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL)
+ }
+ logError(s"Error: Master did not recognize driver $submissionId.")
+ }
+
+ /** Log the response sent by the server in the REST application submission protocol. */
+ private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = {
+ logInfo(s"Server responded with ${response.messageType}:\n${response.toJson}")
+ }
+
+ /** Log an appropriate error if the response sent by the server is not of the expected type. */
+ private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = {
+ logError(s"Error: Server responded with message of unexpected type ${unexpected.messageType}.")
+ }
+}
+
+private[spark] object StandaloneRestClient {
+ val REPORT_DRIVER_STATUS_INTERVAL = 1000
+ val REPORT_DRIVER_STATUS_MAX_TRIES = 10
+ val PROTOCOL_VERSION = "v1"
+
+ /** Submit an application, assuming Spark parameters are specified through system properties. */
+ def main(args: Array[String]): Unit = {
+ if (args.size < 2) {
+ sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]")
+ sys.exit(1)
+ }
+ val appResource = args(0)
+ val mainClass = args(1)
+ val appArgs = args.slice(2, args.size)
+ val conf = new SparkConf
+ val master = conf.getOption("spark.master").getOrElse {
+ throw new IllegalArgumentException("'spark.master' must be set.")
+ }
+ val sparkProperties = conf.getAll.toMap
+ val environmentVariables = sys.env.filter { case (k, _) => k.startsWith("SPARK_") }
+ val client = new StandaloneRestClient
+ val submitRequest = client.constructSubmitRequest(
+ appResource, mainClass, appArgs, sparkProperties, environmentVariables)
+ client.createSubmission(master, submitRequest)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
new file mode 100644
index 0000000000..2033d67e1f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -0,0 +1,449 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.io.{DataOutputStream, File}
+import java.net.InetSocketAddress
+import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
+
+import scala.io.Source
+
+import akka.actor.ActorRef
+import com.fasterxml.jackson.databind.JsonMappingException
+import com.google.common.base.Charsets
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
+import org.eclipse.jetty.util.thread.QueuedThreadPool
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
+import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription}
+import org.apache.spark.deploy.ClientArguments._
+
+/**
+ * A server that responds to requests submitted by the [[StandaloneRestClient]].
+ * This is intended to be embedded in the standalone Master and used in cluster mode only.
+ *
+ * This server responds with different HTTP codes depending on the situation:
+ * 200 OK - Request was processed successfully
+ * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type
+ * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand
+ * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request
+ *
+ * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]]
+ * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]]
+ * instead of the one expected by the client. If the construction of this error response itself
+ * fails, the response will consist of an empty body with a response code that indicates internal
+ * server error.
+ *
+ * @param host the address this server should bind to
+ * @param requestedPort the port this server will attempt to bind to
+ * @param masterActor reference to the Master actor to which requests can be sent
+ * @param masterUrl the URL of the Master new drivers will attempt to connect to
+ * @param masterConf the conf used by the Master
+ */
+private[spark] class StandaloneRestServer(
+ host: String,
+ requestedPort: Int,
+ masterActor: ActorRef,
+ masterUrl: String,
+ masterConf: SparkConf)
+ extends Logging {
+
+ import StandaloneRestServer._
+
+ private var _server: Option[Server] = None
+ private val baseContext = s"/$PROTOCOL_VERSION/submissions"
+
+ // A mapping from servlets to the URL prefixes they are responsible for
+ private val servletToContext = Map[StandaloneRestServlet, String](
+ new SubmitRequestServlet(masterActor, masterUrl, masterConf) -> s"$baseContext/create/*",
+ new KillRequestServlet(masterActor, masterConf) -> s"$baseContext/kill/*",
+ new StatusRequestServlet(masterActor, masterConf) -> s"$baseContext/status/*",
+ new ErrorServlet -> "/" // default handler
+ )
+
+ /** Start the server and return the bound port. */
+ def start(): Int = {
+ val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf)
+ _server = Some(server)
+ logInfo(s"Started REST server for submitting applications on port $boundPort")
+ boundPort
+ }
+
+ /**
+ * Map the servlets to their corresponding contexts and attach them to a server.
+ * Return a 2-tuple of the started server and the bound port.
+ */
+ private def doStart(startPort: Int): (Server, Int) = {
+ val server = new Server(new InetSocketAddress(host, startPort))
+ val threadPool = new QueuedThreadPool
+ threadPool.setDaemon(true)
+ server.setThreadPool(threadPool)
+ val mainHandler = new ServletContextHandler
+ mainHandler.setContextPath("/")
+ servletToContext.foreach { case (servlet, prefix) =>
+ mainHandler.addServlet(new ServletHolder(servlet), prefix)
+ }
+ server.setHandler(mainHandler)
+ server.start()
+ val boundPort = server.getConnectors()(0).getLocalPort
+ (server, boundPort)
+ }
+
+ def stop(): Unit = {
+ _server.foreach(_.stop())
+ }
+}
+
+private object StandaloneRestServer {
+ val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION
+ val SC_UNKNOWN_PROTOCOL_VERSION = 468
+}
+
+/**
+ * An abstract servlet for handling requests passed to the [[StandaloneRestServer]].
+ */
+private abstract class StandaloneRestServlet extends HttpServlet with Logging {
+
+ /** Service a request. If an exception is thrown in the process, indicate server error. */
+ protected override def service(
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ try {
+ super.service(request, response)
+ } catch {
+ case e: Exception =>
+ logError("Exception while handling request", e)
+ response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+ }
+ }
+
+ /**
+ * Serialize the given response message to JSON and send it through the response servlet.
+ * This validates the response before sending it to ensure it is properly constructed.
+ */
+ protected def sendResponse(
+ responseMessage: SubmitRestProtocolResponse,
+ responseServlet: HttpServletResponse): Unit = {
+ val message = validateResponse(responseMessage, responseServlet)
+ responseServlet.setContentType("application/json")
+ responseServlet.setCharacterEncoding("utf-8")
+ responseServlet.setStatus(HttpServletResponse.SC_OK)
+ val content = message.toJson.getBytes(Charsets.UTF_8)
+ val out = new DataOutputStream(responseServlet.getOutputStream)
+ out.write(content)
+ out.close()
+ }
+
+ /**
+ * Return any fields in the client request message that the server does not know about.
+ *
+ * The mechanism for this is to reconstruct the JSON on the server side and compare the
+ * diff between this JSON and the one generated on the client side. Any fields that are
+ * only in the client JSON are treated as unexpected.
+ */
+ protected def findUnknownFields(
+ requestJson: String,
+ requestMessage: SubmitRestProtocolMessage): Array[String] = {
+ val clientSideJson = parse(requestJson)
+ val serverSideJson = parse(requestMessage.toJson)
+ val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson)
+ unknown match {
+ case j: JObject => j.obj.map { case (k, _) => k }.toArray
+ case _ => Array.empty[String] // No difference
+ }
+ }
+
+ /** Return a human readable String representation of the exception. */
+ protected def formatException(e: Throwable): String = {
+ val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n")
+ s"$e\n$stackTraceString"
+ }
+
+ /** Construct an error message to signal the fact that an exception has been thrown. */
+ protected def handleError(message: String): ErrorResponse = {
+ val e = new ErrorResponse
+ e.serverSparkVersion = sparkVersion
+ e.message = message
+ e
+ }
+
+ /**
+ * Validate the response to ensure that it is correctly constructed.
+ *
+ * If it is, simply return the message as is. Otherwise, return an error response instead
+ * to propagate the exception back to the client and set the appropriate error code.
+ */
+ private def validateResponse(
+ responseMessage: SubmitRestProtocolResponse,
+ responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
+ try {
+ responseMessage.validate()
+ responseMessage
+ } catch {
+ case e: Exception =>
+ responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+ handleError("Internal server error: " + formatException(e))
+ }
+ }
+}
+
+/**
+ * A servlet for handling kill requests passed to the [[StandaloneRestServer]].
+ */
+private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
+ extends StandaloneRestServlet {
+
+ /**
+ * If a submission ID is specified in the URL, have the Master kill the corresponding
+ * driver and return an appropriate response to the client. Otherwise, return error.
+ */
+ protected override def doPost(
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ val submissionId = request.getPathInfo.stripPrefix("/")
+ val responseMessage =
+ if (submissionId.nonEmpty) {
+ handleKill(submissionId)
+ } else {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError("Submission ID is missing in kill request.")
+ }
+ sendResponse(responseMessage, response)
+ }
+
+ private def handleKill(submissionId: String): KillSubmissionResponse = {
+ val askTimeout = AkkaUtils.askTimeout(conf)
+ val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
+ DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
+ val k = new KillSubmissionResponse
+ k.serverSparkVersion = sparkVersion
+ k.message = response.message
+ k.submissionId = submissionId
+ k.success = response.success
+ k
+ }
+}
+
+/**
+ * A servlet for handling status requests passed to the [[StandaloneRestServer]].
+ */
+private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
+ extends StandaloneRestServlet {
+
+ /**
+ * If a submission ID is specified in the URL, request the status of the corresponding
+ * driver from the Master and include it in the response. Otherwise, return error.
+ */
+ protected override def doGet(
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ val submissionId = request.getPathInfo.stripPrefix("/")
+ val responseMessage =
+ if (submissionId.nonEmpty) {
+ handleStatus(submissionId)
+ } else {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError("Submission ID is missing in status request.")
+ }
+ sendResponse(responseMessage, response)
+ }
+
+ private def handleStatus(submissionId: String): SubmissionStatusResponse = {
+ val askTimeout = AkkaUtils.askTimeout(conf)
+ val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
+ DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
+ val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) }
+ val d = new SubmissionStatusResponse
+ d.serverSparkVersion = sparkVersion
+ d.submissionId = submissionId
+ d.success = response.found
+ d.driverState = response.state.map(_.toString).orNull
+ d.workerId = response.workerId.orNull
+ d.workerHostPort = response.workerHostPort.orNull
+ d.message = message.orNull
+ d
+ }
+}
+
+/**
+ * A servlet for handling submit requests passed to the [[StandaloneRestServer]].
+ */
+private class SubmitRequestServlet(
+ masterActor: ActorRef,
+ masterUrl: String,
+ conf: SparkConf)
+ extends StandaloneRestServlet {
+
+ /**
+ * Submit an application to the Master with parameters specified in the request.
+ *
+ * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON.
+ * If the request is successfully processed, return an appropriate response to the
+ * client indicating so. Otherwise, return error instead.
+ */
+ protected override def doPost(
+ requestServlet: HttpServletRequest,
+ responseServlet: HttpServletResponse): Unit = {
+ val responseMessage =
+ try {
+ val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString
+ val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
+ // The response should have already been validated on the client.
+ // In case this is not true, validate it ourselves to avoid potential NPEs.
+ requestMessage.validate()
+ handleSubmit(requestMessageJson, requestMessage, responseServlet)
+ } catch {
+ // The client failed to provide a valid JSON, so this is not our fault
+ case e @ (_: JsonMappingException | _: SubmitRestProtocolException) =>
+ responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError("Malformed request: " + formatException(e))
+ }
+ sendResponse(responseMessage, responseServlet)
+ }
+
+ /**
+ * Handle the submit request and construct an appropriate response to return to the client.
+ *
+ * This assumes that the request message is already successfully validated.
+ * If the request message is not of the expected type, return an error to the client.
+ */
+ private def handleSubmit(
+ requestMessageJson: String,
+ requestMessage: SubmitRestProtocolMessage,
+ responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
+ requestMessage match {
+ case submitRequest: CreateSubmissionRequest =>
+ val askTimeout = AkkaUtils.askTimeout(conf)
+ val driverDescription = buildDriverDescription(submitRequest)
+ val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse](
+ DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout)
+ val submitResponse = new CreateSubmissionResponse
+ submitResponse.serverSparkVersion = sparkVersion
+ submitResponse.message = response.message
+ submitResponse.success = response.success
+ submitResponse.submissionId = response.driverId.orNull
+ val unknownFields = findUnknownFields(requestMessageJson, requestMessage)
+ if (unknownFields.nonEmpty) {
+ // If there are fields that the server does not know about, warn the client
+ submitResponse.unknownFields = unknownFields
+ }
+ submitResponse
+ case unexpected =>
+ responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError(s"Received message of unexpected type ${unexpected.messageType}.")
+ }
+ }
+
+ /**
+ * Build a driver description from the fields specified in the submit request.
+ *
+ * This involves constructing a command that takes into account memory, Java options,
+ * classpath and other settings to launch the driver. This does not currently consider
+ * fields used by Python applications since Python is not yet supported in standalone
+ * cluster mode.
+ */
+ private def buildDriverDescription(request: CreateSubmissionRequest): DriverDescription = {
+ // Required fields, including the main class because Python is not yet supported
+ val appResource = Option(request.appResource).getOrElse {
+ throw new SubmitRestMissingFieldException("Application jar is missing.")
+ }
+ val mainClass = Option(request.mainClass).getOrElse {
+ throw new SubmitRestMissingFieldException("Main class is missing.")
+ }
+
+ // Optional fields
+ val sparkProperties = request.sparkProperties
+ val driverMemory = sparkProperties.get("spark.driver.memory")
+ val driverCores = sparkProperties.get("spark.driver.cores")
+ val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions")
+ val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath")
+ val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath")
+ val superviseDriver = sparkProperties.get("spark.driver.supervise")
+ val appArgs = request.appArgs
+ val environmentVariables = request.environmentVariables
+
+ // Construct driver description
+ val conf = new SparkConf(false)
+ .setAll(sparkProperties)
+ .set("spark.master", masterUrl)
+ val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator))
+ val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator))
+ val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty)
+ val sparkJavaOpts = Utils.sparkJavaOpts(conf)
+ val javaOpts = sparkJavaOpts ++ extraJavaOpts
+ val command = new Command(
+ "org.apache.spark.deploy.worker.DriverWrapper",
+ Seq("{{WORKER_URL}}", mainClass) ++ appArgs, // args to the DriverWrapper
+ environmentVariables, extraClassPath, extraLibraryPath, javaOpts)
+ val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY)
+ val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES)
+ val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE)
+ new DriverDescription(
+ appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command)
+ }
+}
+
+/**
+ * A default servlet that handles error cases that are not captured by other servlets.
+ */
+private class ErrorServlet extends StandaloneRestServlet {
+ private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION
+
+ /** Service a faulty request by returning an appropriate error message to the client. */
+ protected override def service(
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ val path = request.getPathInfo
+ // Use a List of non-empty path segments so that the Nil and :: patterns below match as intended
+ val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList
+ var versionMismatch = false
+ var msg =
+ parts match {
+ case Nil =>
+ // http://host:port/
+ "Missing protocol version."
+ case `serverVersion` :: Nil =>
+ // http://host:port/correct-version
+ "Missing the /submissions prefix."
+ case `serverVersion` :: "submissions" :: Nil =>
+ // http://host:port/correct-version/submissions
+ "Missing an action: please specify one of /create, /kill, or /status."
+ case unknownVersion :: _ =>
+ // http://host:port/unknown-version/*
+ versionMismatch = true
+ s"Unknown protocol version '$unknownVersion'."
+ case _ =>
+ // never reached
+ s"Malformed path $path."
+ }
+ msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..."
+ val error = handleError(msg)
+ // If there is a version mismatch, include the highest protocol version that
+ // this server supports in case the client wants to retry with our version
+ if (versionMismatch) {
+ error.highestProtocolVersion = serverVersion
+ response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION)
+ } else {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ }
+ sendResponse(error, response)
+ }
+}
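
The servlets above are reachable over plain HTTP once the Master's REST server is running. As a rough, non-authoritative sketch of what a status check looks like from the outside (the host, the port 6066, the "v1" path segment, and the submission ID below are illustrative assumptions, not values taken from this patch):

import java.net.{HttpURLConnection, URL}
import scala.io.Source

object RestStatusCheckSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical endpoint: http://<master-host>:<rest-port>/<version>/submissions/status/<id>
    val url = new URL("http://localhost:6066/v1/submissions/status/driver-20150205000000-0000")
    val conn = url.openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("GET")
    // The StatusRequestServlet is expected to reply with a SubmissionStatusResponse as JSON
    println(Source.fromInputStream(conn.getInputStream).mkString)
  }
}

A request whose path does not follow this layout falls through to the ErrorServlet, which replies with an ErrorResponse and, on a version mismatch, includes the highest protocol version the server supports.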
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala
new file mode 100644
index 0000000000..d7a0bdbe10
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+/**
+ * An exception thrown in the REST application submission protocol.
+ */
+private[spark] class SubmitRestProtocolException(message: String, cause: Throwable = null)
+ extends Exception(message, cause)
+
+/**
+ * An exception thrown if a field is missing from a [[SubmitRestProtocolMessage]].
+ */
+private[spark] class SubmitRestMissingFieldException(message: String)
+ extends SubmitRestProtocolException(message)
+
+/**
+ * An exception thrown if the REST client cannot reach the REST server.
+ */
+private[spark] class SubmitRestConnectionException(message: String, cause: Throwable)
+ extends SubmitRestProtocolException(message, cause)
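
Since SubmitRestConnectionException extends SubmitRestProtocolException, a caller that wants to treat an unreachable server differently from a malformed message should match on the subclass first. A minimal sketch, not part of this patch; the submit() stub below only simulates a dead server, and the object is assumed to live in org.apache.spark.deploy.rest because the exception classes are private[spark]:

package org.apache.spark.deploy.rest

object ExceptionHandlingSketch {
  // Hypothetical stand-in for a call into the REST client
  private def submit(): Unit =
    throw new SubmitRestConnectionException("Unable to reach the server", new java.io.IOException)

  def main(args: Array[String]): Unit = {
    try {
      submit()
    } catch {
      case e: SubmitRestConnectionException =>
        println("Server unreachable, could fall back to the legacy gateway: " + e.getMessage)
      case e: SubmitRestProtocolException =>
        println("Malformed or invalid message: " + e.getMessage)
    }
  }
}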
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
new file mode 100644
index 0000000000..b877898231
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import scala.util.Try
+
+import com.fasterxml.jackson.annotation._
+import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility
+import com.fasterxml.jackson.annotation.JsonInclude.Include
+import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper, SerializationFeature}
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
+import org.json4s.JsonAST._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.util.Utils
+
+/**
+ * An abstract message exchanged in the REST application submission protocol.
+ *
+ * This message is intended to be serialized to and deserialized from JSON in the exchange.
+ * Each message can either be a request or a response and consists of three common fields:
+ * (1) the action, which fully specifies the type of the message
+ * (2) the Spark version of the client / server
+ * (3) an optional message
+ */
+@JsonInclude(Include.NON_NULL)
+@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY)
+@JsonPropertyOrder(alphabetic = true)
+private[spark] abstract class SubmitRestProtocolMessage {
+ @JsonIgnore
+ val messageType = Utils.getFormattedClassName(this)
+
+ val action: String = messageType
+ var message: String = null
+
+ // For JSON deserialization
+ private def setAction(a: String): Unit = { }
+
+ /**
+ * Serialize the message to JSON.
+ * This also ensures that the message is valid and its fields are in the expected format.
+ */
+ def toJson: String = {
+ validate()
+ SubmitRestProtocolMessage.mapper.writeValueAsString(this)
+ }
+
+ /**
+ * Assert the validity of the message.
+ * If the validation fails, throw a [[SubmitRestProtocolException]].
+ */
+ final def validate(): Unit = {
+ try {
+ doValidate()
+ } catch {
+ case e: Exception =>
+ throw new SubmitRestProtocolException(s"Validation of message $messageType failed!", e)
+ }
+ }
+
+ /** Assert the validity of the message */
+ protected def doValidate(): Unit = {
+ if (action == null) {
+ throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType")
+ }
+ }
+
+ /** Assert that the specified field is set in this message. */
+ protected def assertFieldIsSet[T](value: T, name: String): Unit = {
+ if (value == null) {
+ throw new SubmitRestMissingFieldException(s"'$name' is missing in message $messageType.")
+ }
+ }
+
+ /**
+ * Assert a condition when validating this message.
+ * If the assertion fails, throw a [[SubmitRestProtocolException]].
+ */
+ protected def assert(condition: Boolean, failMessage: String): Unit = {
+ if (!condition) { throw new SubmitRestProtocolException(failMessage) }
+ }
+}
+
+/**
+ * Helper methods to process serialized [[SubmitRestProtocolMessage]]s.
+ */
+private[spark] object SubmitRestProtocolMessage {
+ private val packagePrefix = this.getClass.getPackage.getName
+ private val mapper = new ObjectMapper()
+ .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+ .enable(SerializationFeature.INDENT_OUTPUT)
+ .registerModule(DefaultScalaModule)
+
+ /**
+ * Parse the value of the action field from the given JSON.
+ * If the action field is not found, throw a [[SubmitRestMissingFieldException]].
+ */
+ def parseAction(json: String): String = {
+ parse(json).asInstanceOf[JObject].obj
+ .find { case (f, _) => f == "action" }
+ .map { case (_, v) => v.asInstanceOf[JString].s }
+ .getOrElse {
+ throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json")
+ }
+ }
+
+ /**
+ * Construct a [[SubmitRestProtocolMessage]] from its JSON representation.
+ *
+ * This method first parses the action from the JSON and uses it to infer the message type.
+ * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in
+ * this package. Otherwise, a [[ClassNotFoundException]] will be thrown.
+ */
+ def fromJson(json: String): SubmitRestProtocolMessage = {
+ val className = parseAction(json)
+ val clazz = Class.forName(packagePrefix + "." + className)
+ .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage])
+ fromJson(json, clazz)
+ }
+
+ /**
+ * Construct a [[SubmitRestProtocolMessage]] from its JSON representation.
+ *
+ * This method determines the type of the message from the class provided instead of
+ * inferring it from the action field. This is useful for deserializing JSON that
+ * represents custom user-defined messages.
+ */
+ def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = {
+ mapper.readValue(json, clazz)
+ }
+}
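
A minimal round-trip sketch of the helpers above, assuming it lives in org.apache.spark.deploy.rest (the messages are private[spark]) and using the KillSubmissionResponse type added elsewhere in this patch; the version string and submission ID are made up:

package org.apache.spark.deploy.rest

object MessageRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val response = new KillSubmissionResponse
    response.serverSparkVersion = "1.3.0"
    response.submissionId = "driver-0"
    response.success = true
    val json = response.toJson  // validates first, then serializes with Jackson
    // The action field records the concrete class name, so the generic fromJson can rebuild it
    assert(SubmitRestProtocolMessage.parseAction(json) == "KillSubmissionResponse")
    assert(SubmitRestProtocolMessage.fromJson(json).isInstanceOf[KillSubmissionResponse])
  }
}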
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala
new file mode 100644
index 0000000000..9e1fd8c40c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import scala.util.Try
+
+import org.apache.spark.util.Utils
+
+/**
+ * An abstract request sent from the client in the REST application submission protocol.
+ */
+private[spark] abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage {
+ var clientSparkVersion: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(clientSparkVersion, "clientSparkVersion")
+ }
+}
+
+/**
+ * A request to launch a new application in the REST application submission protocol.
+ */
+private[spark] class CreateSubmissionRequest extends SubmitRestProtocolRequest {
+ var appResource: String = null
+ var mainClass: String = null
+ var appArgs: Array[String] = null
+ var sparkProperties: Map[String, String] = null
+ var environmentVariables: Map[String, String] = null
+
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assert(sparkProperties != null, "No Spark properties set!")
+ assertFieldIsSet(appResource, "appResource")
+ assertPropertyIsSet("spark.app.name")
+ assertPropertyIsBoolean("spark.driver.supervise")
+ assertPropertyIsNumeric("spark.driver.cores")
+ assertPropertyIsNumeric("spark.cores.max")
+ assertPropertyIsMemory("spark.driver.memory")
+ assertPropertyIsMemory("spark.executor.memory")
+ }
+
+ private def assertPropertyIsSet(key: String): Unit =
+ assertFieldIsSet(sparkProperties.getOrElse(key, null), key)
+
+ private def assertPropertyIsBoolean(key: String): Unit =
+ assertProperty[Boolean](key, "boolean", _.toBoolean)
+
+ private def assertPropertyIsNumeric(key: String): Unit =
+ assertProperty[Int](key, "numeric", _.toInt)
+
+ private def assertPropertyIsMemory(key: String): Unit =
+ assertProperty[Int](key, "memory", Utils.memoryStringToMb)
+
+ /** Assert that a Spark property can be converted to a certain type. */
+ private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = {
+ sparkProperties.get(key).foreach { value =>
+ Try(convert(value)).getOrElse {
+ throw new SubmitRestProtocolException(
+ s"Property '$key' expected $valueType value: actual was '$value'.")
+ }
+ }
+ }
+}
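
To illustrate the property type checks above, the sketch below builds a request whose spark.driver.cores value is deliberately non-numeric, so validate() should reject it. All field values are invented for illustration, and the object is assumed to live in org.apache.spark.deploy.rest:

package org.apache.spark.deploy.rest

object RequestValidationSketch {
  def main(args: Array[String]): Unit = {
    val req = new CreateSubmissionRequest
    req.clientSparkVersion = "1.3.0"
    req.appResource = "hdfs://example/app.jar"
    req.mainClass = "com.example.Main"
    req.sparkProperties = Map(
      "spark.app.name" -> "demo",
      "spark.driver.cores" -> "not-a-number")  // trips assertPropertyIsNumeric
    try {
      req.validate()
    } catch {
      case e: SubmitRestProtocolException =>
        println("Validation failed as expected: " + e.getMessage)
    }
  }
}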
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala
new file mode 100644
index 0000000000..16dfe041d4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.lang.Boolean
+
+/**
+ * An abstract response sent from the server in the REST application submission protocol.
+ */
+private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
+ var serverSparkVersion: String = null
+ var success: Boolean = null
+ var unknownFields: Array[String] = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(serverSparkVersion, "serverSparkVersion")
+ }
+}
+
+/**
+ * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol.
+ */
+private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse {
+ var submissionId: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(success, "success")
+ }
+}
+
+/**
+ * A response to a kill request in the REST application submission protocol.
+ */
+private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse {
+ var submissionId: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(submissionId, "submissionId")
+ assertFieldIsSet(success, "success")
+ }
+}
+
+/**
+ * A response to a status request in the REST application submission protocol.
+ */
+private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse {
+ var submissionId: String = null
+ var driverState: String = null
+ var workerId: String = null
+ var workerHostPort: String = null
+
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(submissionId, "submissionId")
+ assertFieldIsSet(success, "success")
+ }
+}
+
+/**
+ * An error response message used in the REST application submission protocol.
+ */
+private[spark] class ErrorResponse extends SubmitRestProtocolResponse {
+ // The highest protocol version that the server knows about
+ // This is set when the client specifies an unknown version
+ var highestProtocolVersion: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(message, "message")
+ }
+}
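
A small sketch of how a client might branch on these response types, including the retry hint carried by ErrorResponse.highestProtocolVersion; the ErrorResponse built in main is fabricated for illustration, and the object is assumed to live in org.apache.spark.deploy.rest:

package org.apache.spark.deploy.rest

object ResponseHandlingSketch {
  def describe(response: SubmitRestProtocolResponse): String = response match {
    case e: ErrorResponse if e.highestProtocolVersion != null =>
      // Version mismatch: the server advertises the highest version it understands
      s"Retry with protocol version ${e.highestProtocolVersion}"
    case e: ErrorResponse =>
      s"Request failed: ${e.message}"
    case s: SubmissionStatusResponse =>
      s"Driver ${s.submissionId} is in state ${s.driverState}"
    case other =>
      other.toJson
  }

  def main(args: Array[String]): Unit = {
    val error = new ErrorResponse
    error.serverSparkVersion = "1.3.0"
    error.message = "Unknown protocol version 'v0'."
    error.highestProtocolVersion = "v1"
    println(describe(error))
  }
}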
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index ed02ca81e4..e955636cf5 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -68,7 +68,8 @@ class JsonProtocolSuite extends FunSuite {
val completedApps = Array[ApplicationInfo]()
val activeDrivers = Array(createDriverInfo())
val completedDrivers = Array(createDriverInfo())
- val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps,
+ val stateResponse = new MasterStateResponse(
+ "host", 8080, None, workers, activeApps, completedApps,
activeDrivers, completedDrivers, RecoveryState.ALIVE)
val output = JsonProtocol.writeMasterState(stateResponse)
assertValidJson(output)
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 3f1355f828..1ddccae126 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -141,7 +141,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
val childArgsStr = childArgs.mkString(" ")
childArgsStr should include ("--class org.SomeClass")
childArgsStr should include ("--executor-memory 5g")
@@ -180,7 +180,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
childArgs.mkString(" ") should be ("arg1 arg2")
mainClass should be ("org.SomeClass")
classpath should have length (4)
@@ -201,6 +201,18 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
}
test("handles standalone cluster mode") {
+ testStandaloneCluster(useRest = true)
+ }
+
+ test("handles legacy standalone cluster mode") {
+ testStandaloneCluster(useRest = false)
+ }
+
+ /**
+ * Test whether the launch environment is correctly set up in standalone cluster mode.
+ * @param useRest whether to use the REST submission gateway introduced in Spark 1.3
+ */
+ private def testStandaloneCluster(useRest: Boolean): Unit = {
val clArgs = Seq(
"--deploy-mode", "cluster",
"--master", "spark://h:p",
@@ -212,17 +224,26 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ appArgs.useRest = useRest
+ val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
val childArgsStr = childArgs.mkString(" ")
- childArgsStr should startWith ("--memory 4g --cores 5 --supervise")
- childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2")
- mainClass should be ("org.apache.spark.deploy.Client")
- classpath should have size (0)
- sysProps should have size (5)
+ if (useRest) {
+ childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2")
+ mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient")
+ } else {
+ childArgsStr should startWith ("--supervise --memory 4g --cores 5")
+ childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2"
+ mainClass should be ("org.apache.spark.deploy.Client")
+ }
+ classpath should have size 0
+ sysProps should have size 8
sysProps.keys should contain ("SPARK_SUBMIT")
sysProps.keys should contain ("spark.master")
sysProps.keys should contain ("spark.app.name")
sysProps.keys should contain ("spark.jars")
+ sysProps.keys should contain ("spark.driver.memory")
+ sysProps.keys should contain ("spark.driver.cores")
+ sysProps.keys should contain ("spark.driver.supervise")
sysProps.keys should contain ("spark.shuffle.spill")
sysProps("spark.shuffle.spill") should be ("false")
}
@@ -239,7 +260,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
childArgs.mkString(" ") should be ("arg1 arg2")
mainClass should be ("org.SomeClass")
classpath should have length (1)
@@ -261,7 +282,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
childArgs.mkString(" ") should be ("arg1 arg2")
mainClass should be ("org.SomeClass")
classpath should have length (1)
@@ -281,7 +302,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (_, _, sysProps, mainClass) = createLaunchEnv(appArgs)
+ val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
sysProps("spark.executor.memory") should be ("5g")
sysProps("spark.master") should be ("yarn-cluster")
mainClass should be ("org.apache.spark.deploy.yarn.Client")
@@ -339,7 +360,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"--files", files,
"thejar.jar")
val appArgs = new SparkSubmitArguments(clArgs)
- val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3
+ val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3
appArgs.jars should be (Utils.resolveURIs(jars))
appArgs.files should be (Utils.resolveURIs(files))
sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar"))
@@ -354,7 +375,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar"
)
val appArgs2 = new SparkSubmitArguments(clArgs2)
- val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3
+ val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3
appArgs2.files should be (Utils.resolveURIs(files))
appArgs2.archives should be (Utils.resolveURIs(archives))
sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files))
@@ -367,7 +388,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"mister.py"
)
val appArgs3 = new SparkSubmitArguments(clArgs3)
- val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3
+ val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3
appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles))
sysProps3("spark.submit.pyFiles") should be (
PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(","))
@@ -392,7 +413,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar"
)
val appArgs = new SparkSubmitArguments(clArgs)
- val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3
+ val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3
sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar"))
sysProps("spark.files") should be(Utils.resolveURIs(files))
@@ -409,7 +430,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"thejar.jar"
)
val appArgs2 = new SparkSubmitArguments(clArgs2)
- val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3
+ val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3
sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files))
sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives))
@@ -424,7 +445,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"mister.py"
)
val appArgs3 = new SparkSubmitArguments(clArgs3)
- val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3
+ val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3
sysProps3("spark.submit.pyFiles") should be(
PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(","))
}
@@ -440,7 +461,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path))
assert(appArgs.propertiesFile != null)
assert(appArgs.propertiesFile.startsWith(path))
- appArgs.executorMemory should be ("2.3g")
+ appArgs.executorMemory should be ("2.3g")
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
new file mode 100644
index 0000000000..29aed89b67
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.io.{File, FileInputStream, FileOutputStream, PrintWriter}
+import java.util.jar.{JarEntry, JarOutputStream}
+import java.util.zip.ZipEntry
+
+import scala.collection.mutable.ArrayBuffer
+import scala.io.Source
+
+import akka.actor.ActorSystem
+import com.google.common.io.ByteStreams
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+import org.scalatest.exceptions.TestFailedException
+
+import org.apache.spark._
+import org.apache.spark.util.Utils
+import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
+import org.apache.spark.deploy.master.{DriverState, Master}
+import org.apache.spark.deploy.worker.Worker
+
+/**
+ * End-to-end tests for the REST application submission protocol in standalone mode.
+ */
+class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
+ private val systemsToStop = new ArrayBuffer[ActorSystem]
+ private val masterRestUrl = startLocalCluster()
+ private val client = new StandaloneRestClient
+ private val mainJar = StandaloneRestSubmitSuite.createJar()
+ private val mainClass = StandaloneRestApp.getClass.getName.stripSuffix("$")
+
+ override def afterAll() {
+ systemsToStop.foreach(_.shutdown())
+ }
+
+ test("simple submit until completion") {
+ val resultsFile = File.createTempFile("test-submit", ".txt")
+ val numbers = Seq(1, 2, 3)
+ val size = 500
+ val submissionId = submitApplication(resultsFile, numbers, size)
+ waitUntilFinished(submissionId)
+ validateResult(resultsFile, numbers, size)
+ }
+
+ test("kill empty submission") {
+ val response = client.killSubmission(masterRestUrl, "submission-that-does-not-exist")
+ val killResponse = getKillResponse(response)
+ val killSuccess = killResponse.success
+ assert(!killSuccess)
+ }
+
+ test("kill running submission") {
+ val resultsFile = File.createTempFile("test-kill", ".txt")
+ val numbers = Seq(1, 2, 3)
+ val size = 500
+ val submissionId = submitApplication(resultsFile, numbers, size)
+ val response = client.killSubmission(masterRestUrl, submissionId)
+ val killResponse = getKillResponse(response)
+ val killSuccess = killResponse.success
+ waitUntilFinished(submissionId)
+ val response2 = client.requestSubmissionStatus(masterRestUrl, submissionId)
+ val statusResponse = getStatusResponse(response2)
+ val statusSuccess = statusResponse.success
+ val driverState = statusResponse.driverState
+ assert(killSuccess)
+ assert(statusSuccess)
+ assert(driverState === DriverState.KILLED.toString)
+ // we should not see the expected results because we killed the submission
+ intercept[TestFailedException] { validateResult(resultsFile, numbers, size) }
+ }
+
+ test("request status for empty submission") {
+ val response = client.requestSubmissionStatus(masterRestUrl, "submission-that-does-not-exist")
+ val statusResponse = getStatusResponse(response)
+ val statusSuccess = statusResponse.success
+ assert(!statusSuccess)
+ }
+
+ /**
+ * Start a local cluster containing one Master and a few Workers.
+ * Do not use [[org.apache.spark.deploy.LocalSparkCluster]] here because we want the REST URL.
+ * Return the Master's REST URL to which applications should be submitted.
+ */
+ private def startLocalCluster(): String = {
+ val conf = new SparkConf(false)
+ .set("spark.master.rest.enabled", "true")
+ .set("spark.master.rest.port", "0")
+ val (numWorkers, coresPerWorker, memPerWorker) = (2, 1, 512)
+ val localHostName = Utils.localHostName()
+ val (masterSystem, masterPort, _, _masterRestPort) =
+ Master.startSystemAndActor(localHostName, 0, 0, conf)
+ val masterRestPort = _masterRestPort.getOrElse { fail("REST server not started on Master!") }
+ val masterUrl = "spark://" + localHostName + ":" + masterPort
+ val masterRestUrl = "spark://" + localHostName + ":" + masterRestPort
+ (1 to numWorkers).foreach { n =>
+ val (workerSystem, _) = Worker.startSystemAndActor(
+ localHostName, 0, 0, coresPerWorker, memPerWorker, Array(masterUrl), null, Some(n))
+ systemsToStop.append(workerSystem)
+ }
+ systemsToStop.append(masterSystem)
+ masterRestUrl
+ }
+
+ /** Submit the [[StandaloneRestApp]] and return the corresponding submission ID. */
+ private def submitApplication(resultsFile: File, numbers: Seq[Int], size: Int): String = {
+ val appArgs = Seq(resultsFile.getAbsolutePath) ++ numbers.map(_.toString) ++ Seq(size.toString)
+ val commandLineArgs = Array(
+ "--deploy-mode", "cluster",
+ "--master", masterRestUrl,
+ "--name", mainClass,
+ "--class", mainClass,
+ mainJar) ++ appArgs
+ val args = new SparkSubmitArguments(commandLineArgs)
+ val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args)
+ val request = client.constructSubmitRequest(
+ mainJar, mainClass, appArgs.toArray, sparkProperties.toMap, Map.empty)
+ val response = client.createSubmission(masterRestUrl, request)
+ val submitResponse = getSubmitResponse(response)
+ val submissionId = submitResponse.submissionId
+ assert(submissionId != null, "Application submission was unsuccessful!")
+ submissionId
+ }
+
+ /** Wait until the given submission has finished running up to the specified timeout. */
+ private def waitUntilFinished(submissionId: String, maxSeconds: Int = 30): Unit = {
+ var finished = false
+ val expireTime = System.currentTimeMillis + maxSeconds * 1000
+ while (!finished) {
+ val response = client.requestSubmissionStatus(masterRestUrl, submissionId)
+ val statusResponse = getStatusResponse(response)
+ val driverState = statusResponse.driverState
+ finished =
+ driverState != DriverState.SUBMITTED.toString &&
+ driverState != DriverState.RUNNING.toString
+ if (System.currentTimeMillis > expireTime) {
+ fail(s"Driver $submissionId did not finish within $maxSeconds seconds.")
+ }
+ }
+ }
+
+ /** Return the response as a submit response, or fail with error otherwise. */
+ private def getSubmitResponse(response: SubmitRestProtocolResponse): CreateSubmissionResponse = {
+ response match {
+ case s: CreateSubmissionResponse => s
+ case e: ErrorResponse => fail(s"Server returned error: ${e.message}")
+ case r => fail(s"Expected submit response. Actual: ${r.toJson}")
+ }
+ }
+
+ /** Return the response as a kill response, or fail with error otherwise. */
+ private def getKillResponse(response: SubmitRestProtocolResponse): KillSubmissionResponse = {
+ response match {
+ case k: KillSubmissionResponse => k
+ case e: ErrorResponse => fail(s"Server returned error: ${e.message}")
+ case r => fail(s"Expected kill response. Actual: ${r.toJson}")
+ }
+ }
+
+ /** Return the response as a status response, or fail with error otherwise. */
+ private def getStatusResponse(response: SubmitRestProtocolResponse): SubmissionStatusResponse = {
+ response match {
+ case s: SubmissionStatusResponse => s
+ case e: ErrorResponse => fail(s"Server returned error: ${e.message}")
+ case r => fail(s"Expected status response. Actual: ${r.toJson}")
+ }
+ }
+
+ /** Validate whether the application produced the correct output. */
+ private def validateResult(resultsFile: File, numbers: Seq[Int], size: Int): Unit = {
+ val lines = Source.fromFile(resultsFile.getAbsolutePath).getLines().toSeq
+ val unexpectedContent =
+ if (lines.nonEmpty) {
+ "[\n" + lines.map { l => " " + l }.mkString("\n") + "\n]"
+ } else {
+ "[EMPTY]"
+ }
+ assert(lines.size === 2, s"Unexpected content in file: $unexpectedContent")
+ assert(lines(0).toInt === numbers.sum, s"Sum of ${numbers.mkString(",")} is incorrect")
+ assert(lines(1).toInt === (size / 2) + 1, "Result of Spark job is incorrect")
+ }
+}
+
+private object StandaloneRestSubmitSuite {
+ private val pathPrefix = this.getClass.getPackage.getName.replaceAll("\\.", "/")
+
+ /**
+ * Create a jar that contains all the class files needed for running the [[StandaloneRestApp]].
+ * Return the absolute path to that jar.
+ */
+ def createJar(): String = {
+ val jarFile = File.createTempFile("test-standalone-rest-protocol", ".jar")
+ val jarFileStream = new FileOutputStream(jarFile)
+ val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest)
+ jarStream.putNextEntry(new ZipEntry(pathPrefix))
+ getClassFiles.foreach { cf =>
+ jarStream.putNextEntry(new JarEntry(pathPrefix + "/" + cf.getName))
+ val in = new FileInputStream(cf)
+ ByteStreams.copy(in, jarStream)
+ in.close()
+ }
+ jarStream.close()
+ jarFileStream.close()
+ jarFile.getAbsolutePath
+ }
+
+ /**
+ * Return a list of class files compiled for [[StandaloneRestApp]].
+ * This includes all the anonymous classes used in the application.
+ */
+ private def getClassFiles: Seq[File] = {
+ val className = Utils.getFormattedClassName(StandaloneRestApp)
+ val clazz = StandaloneRestApp.getClass
+ val basePath = clazz.getProtectionDomain.getCodeSource.getLocation.toURI.getPath
+ val baseDir = new File(basePath + "/" + pathPrefix)
+ baseDir.listFiles().filter(_.getName.contains(className))
+ }
+}
+
+/**
+ * Sample application to be submitted to the cluster using the REST gateway.
+ * All relevant classes will be packaged into a jar at run time.
+ */
+object StandaloneRestApp {
+ // Usage: [path to results file] [num1] [num2] [num3] [rddSize]
+ // The first line of the results file should be (num1 + num2 + num3)
+ // The second line should be (rddSize / 2) + 1
+ def main(args: Array[String]) {
+ assert(args.size == 5, s"Expected exactly 5 arguments: ${args.mkString(",")}")
+ val resultFile = new File(args(0))
+ val writer = new PrintWriter(resultFile)
+ try {
+ val conf = new SparkConf()
+ val sc = new SparkContext(conf)
+ val firstLine = args(1).toInt + args(2).toInt + args(3).toInt
+ val secondLine = sc.parallelize(1 to args(4).toInt)
+ .map { i => (i / 2, i) }
+ .reduceByKey(_ + _)
+ .count()
+ writer.println(firstLine)
+ writer.println(secondLine)
+ } catch {
+ case e: Exception =>
+ writer.println(e)
+ e.getStackTrace.foreach { l => writer.println(" " + l) }
+ } finally {
+ writer.close()
+ }
+ }
+}
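
For concreteness, with the suite's defaults (numbers = 1, 2, 3 and an RDD size of 500) the results file is expected to contain 6 on the first line and 251 on the second, since i / 2 maps 1..500 onto the 251 distinct keys 0..250. A quick local sanity check of that arithmetic, independent of Spark:

object ExpectedResultSketch {
  def main(args: Array[String]): Unit = {
    val numbers = Seq(1, 2, 3)
    val size = 500
    val firstLine = numbers.sum                             // 6
    val secondLine = (1 to size).map(_ / 2).distinct.size   // 251 == size / 2 + 1
    assert(firstLine == 6 && secondLine == size / 2 + 1)
    println(s"$firstLine\n$secondLine")
  }
}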
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala
new file mode 100644
index 0000000000..1d64ec201e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.lang.Boolean
+import java.lang.Integer
+
+import org.json4s.jackson.JsonMethods._
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkConf
+
+/**
+ * Tests for the REST application submission protocol.
+ */
+class SubmitRestProtocolSuite extends FunSuite {
+
+ test("validate") {
+ val request = new DummyRequest
+ intercept[SubmitRestProtocolException] { request.validate() } // missing everything
+ request.clientSparkVersion = "1.2.3"
+ intercept[SubmitRestProtocolException] { request.validate() } // missing name and age
+ request.name = "something"
+ intercept[SubmitRestProtocolException] { request.validate() } // missing only age
+ request.age = 2
+ intercept[SubmitRestProtocolException] { request.validate() } // age too low
+ request.age = 10
+ request.validate() // everything is set properly
+ request.clientSparkVersion = null
+ intercept[SubmitRestProtocolException] { request.validate() } // missing only Spark version
+ request.clientSparkVersion = "1.2.3"
+ request.name = null
+ intercept[SubmitRestProtocolException] { request.validate() } // missing only name
+ request.message = "not-setting-name"
+ intercept[SubmitRestProtocolException] { request.validate() } // still missing name
+ }
+
+ test("request to and from JSON") {
+ val request = new DummyRequest
+ intercept[SubmitRestProtocolException] { request.toJson } // implicit validation
+ request.clientSparkVersion = "1.2.3"
+ request.active = true
+ request.age = 25
+ request.name = "jung"
+ val json = request.toJson
+ assertJsonEquals(json, dummyRequestJson)
+ val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest])
+ assert(newRequest.clientSparkVersion === "1.2.3")
+ assert(newRequest.active)
+ assert(newRequest.age === 25)
+ assert(newRequest.name === "jung")
+ assert(newRequest.message === null)
+ }
+
+ test("response to and from JSON") {
+ val response = new DummyResponse
+ response.serverSparkVersion = "3.3.4"
+ response.success = true
+ val json = response.toJson
+ assertJsonEquals(json, dummyResponseJson)
+ val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse])
+ assert(newResponse.serverSparkVersion === "3.3.4")
+ assert(newResponse.success)
+ assert(newResponse.message === null)
+ }
+
+ test("CreateSubmissionRequest") {
+ val message = new CreateSubmissionRequest
+ intercept[SubmitRestProtocolException] { message.validate() }
+ message.clientSparkVersion = "1.2.3"
+ message.appResource = "honey-walnut-cherry.jar"
+ message.mainClass = "org.apache.spark.examples.SparkPie"
+ val conf = new SparkConf(false)
+ conf.set("spark.app.name", "SparkPie")
+ message.sparkProperties = conf.getAll.toMap
+ message.validate()
+ // optional fields
+ conf.set("spark.jars", "mayonnaise.jar,ketchup.jar")
+ conf.set("spark.files", "fireball.png")
+ conf.set("spark.driver.memory", "512m")
+ conf.set("spark.driver.cores", "180")
+ conf.set("spark.driver.extraJavaOptions", " -Dslices=5 -Dcolor=mostly_red")
+ conf.set("spark.driver.extraClassPath", "food-coloring.jar")
+ conf.set("spark.driver.extraLibraryPath", "pickle.jar")
+ conf.set("spark.driver.supervise", "false")
+ conf.set("spark.executor.memory", "256m")
+ conf.set("spark.cores.max", "10000")
+ message.sparkProperties = conf.getAll.toMap
+ message.appArgs = Array("two slices", "a hint of cinnamon")
+ message.environmentVariables = Map("PATH" -> "/dev/null")
+ message.validate()
+ // bad fields
+ var badConf = conf.clone().set("spark.driver.cores", "one hundred feet")
+ message.sparkProperties = badConf.getAll.toMap
+ intercept[SubmitRestProtocolException] { message.validate() }
+ badConf = conf.clone().set("spark.driver.supervise", "nope, never")
+ message.sparkProperties = badConf.getAll.toMap
+ intercept[SubmitRestProtocolException] { message.validate() }
+ badConf = conf.clone().set("spark.cores.max", "two men")
+ message.sparkProperties = badConf.getAll.toMap
+ intercept[SubmitRestProtocolException] { message.validate() }
+ message.sparkProperties = conf.getAll.toMap
+ // test JSON
+ val json = message.toJson
+ assertJsonEquals(json, submitDriverRequestJson)
+ val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionRequest])
+ assert(newMessage.clientSparkVersion === "1.2.3")
+ assert(newMessage.appResource === "honey-walnut-cherry.jar")
+ assert(newMessage.mainClass === "org.apache.spark.examples.SparkPie")
+ assert(newMessage.sparkProperties("spark.app.name") === "SparkPie")
+ assert(newMessage.sparkProperties("spark.jars") === "mayonnaise.jar,ketchup.jar")
+ assert(newMessage.sparkProperties("spark.files") === "fireball.png")
+ assert(newMessage.sparkProperties("spark.driver.memory") === "512m")
+ assert(newMessage.sparkProperties("spark.driver.cores") === "180")
+ assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red")
+ assert(newMessage.sparkProperties("spark.driver.extraClassPath") === "food-coloring.jar")
+ assert(newMessage.sparkProperties("spark.driver.extraLibraryPath") === "pickle.jar")
+ assert(newMessage.sparkProperties("spark.driver.supervise") === "false")
+ assert(newMessage.sparkProperties("spark.executor.memory") === "256m")
+ assert(newMessage.sparkProperties("spark.cores.max") === "10000")
+ assert(newMessage.appArgs === message.appArgs)
+ assert(newMessage.sparkProperties === message.sparkProperties)
+ assert(newMessage.environmentVariables === message.environmentVariables)
+ }
+
+ test("CreateSubmissionResponse") {
+ val message = new CreateSubmissionResponse
+ intercept[SubmitRestProtocolException] { message.validate() }
+ message.serverSparkVersion = "1.2.3"
+ message.submissionId = "driver_123"
+ message.success = true
+ message.validate()
+ // test JSON
+ val json = message.toJson
+ assertJsonEquals(json, submitDriverResponseJson)
+ val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionResponse])
+ assert(newMessage.serverSparkVersion === "1.2.3")
+ assert(newMessage.submissionId === "driver_123")
+ assert(newMessage.success)
+ }
+
+ test("KillSubmissionResponse") {
+ val message = new KillSubmissionResponse
+ intercept[SubmitRestProtocolException] { message.validate() }
+ message.serverSparkVersion = "1.2.3"
+ message.submissionId = "driver_123"
+ message.success = true
+ message.validate()
+ // test JSON
+ val json = message.toJson
+ assertJsonEquals(json, killDriverResponseJson)
+ val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillSubmissionResponse])
+ assert(newMessage.serverSparkVersion === "1.2.3")
+ assert(newMessage.submissionId === "driver_123")
+ assert(newMessage.success)
+ }
+
+ test("SubmissionStatusResponse") {
+ val message = new SubmissionStatusResponse
+ intercept[SubmitRestProtocolException] { message.validate() }
+ message.serverSparkVersion = "1.2.3"
+ message.submissionId = "driver_123"
+ message.success = true
+ message.validate()
+ // optional fields
+ message.driverState = "RUNNING"
+ message.workerId = "worker_123"
+ message.workerHostPort = "1.2.3.4:7780"
+ // test JSON
+ val json = message.toJson
+ assertJsonEquals(json, driverStatusResponseJson)
+ val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmissionStatusResponse])
+ assert(newMessage.serverSparkVersion === "1.2.3")
+ assert(newMessage.submissionId === "driver_123")
+ assert(newMessage.driverState === "RUNNING")
+ assert(newMessage.success)
+ assert(newMessage.workerId === "worker_123")
+ assert(newMessage.workerHostPort === "1.2.3.4:7780")
+ }
+
+ test("ErrorResponse") {
+ val message = new ErrorResponse
+ intercept[SubmitRestProtocolException] { message.validate() }
+ message.serverSparkVersion = "1.2.3"
+ message.message = "Field not found in submit request: X"
+ message.validate()
+ // test JSON
+ val json = message.toJson
+ assertJsonEquals(json, errorJson)
+ val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[ErrorResponse])
+ assert(newMessage.serverSparkVersion === "1.2.3")
+ assert(newMessage.message === "Field not found in submit request: X")
+ }
+
+ private val dummyRequestJson =
+ """
+ |{
+ | "action" : "DummyRequest",
+ | "active" : true,
+ | "age" : 25,
+ | "clientSparkVersion" : "1.2.3",
+ | "name" : "jung"
+ |}
+ """.stripMargin
+
+ private val dummyResponseJson =
+ """
+ |{
+ | "action" : "DummyResponse",
+ | "serverSparkVersion" : "3.3.4",
+ | "success": true
+ |}
+ """.stripMargin
+
+ private val submitDriverRequestJson =
+ """
+ |{
+ | "action" : "CreateSubmissionRequest",
+ | "appArgs" : [ "two slices", "a hint of cinnamon" ],
+ | "appResource" : "honey-walnut-cherry.jar",
+ | "clientSparkVersion" : "1.2.3",
+ | "environmentVariables" : {
+ | "PATH" : "/dev/null"
+ | },
+ | "mainClass" : "org.apache.spark.examples.SparkPie",
+ | "sparkProperties" : {
+ | "spark.driver.extraLibraryPath" : "pickle.jar",
+ | "spark.jars" : "mayonnaise.jar,ketchup.jar",
+ | "spark.driver.supervise" : "false",
+ | "spark.app.name" : "SparkPie",
+ | "spark.cores.max" : "10000",
+ | "spark.driver.memory" : "512m",
+ | "spark.files" : "fireball.png",
+ | "spark.driver.cores" : "180",
+ | "spark.driver.extraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red",
+ | "spark.executor.memory" : "256m",
+ | "spark.driver.extraClassPath" : "food-coloring.jar"
+ | }
+ |}
+ """.stripMargin
+
+ private val submitDriverResponseJson =
+ """
+ |{
+ | "action" : "CreateSubmissionResponse",
+ | "serverSparkVersion" : "1.2.3",
+ | "submissionId" : "driver_123",
+ | "success" : true
+ |}
+ """.stripMargin
+
+ private val killDriverResponseJson =
+ """
+ |{
+ | "action" : "KillSubmissionResponse",
+ | "serverSparkVersion" : "1.2.3",
+ | "submissionId" : "driver_123",
+ | "success" : true
+ |}
+ """.stripMargin
+
+ private val driverStatusResponseJson =
+ """
+ |{
+ | "action" : "SubmissionStatusResponse",
+ | "driverState" : "RUNNING",
+ | "serverSparkVersion" : "1.2.3",
+ | "submissionId" : "driver_123",
+ | "success" : true,
+ | "workerHostPort" : "1.2.3.4:7780",
+ | "workerId" : "worker_123"
+ |}
+ """.stripMargin
+
+ private val errorJson =
+ """
+ |{
+ | "action" : "ErrorResponse",
+ | "message" : "Field not found in submit request: X",
+ | "serverSparkVersion" : "1.2.3"
+ |}
+ """.stripMargin
+
+ /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. */
+ private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = {
+ val trimmedJson1 = jsonString1.trim
+ val trimmedJson2 = jsonString2.trim
+ val json1 = compact(render(parse(trimmedJson1)))
+ val json2 = compact(render(parse(trimmedJson2)))
+ // Put this on a separate line to avoid printing comparison twice when test fails
+ val equals = json1 == json2
+ assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2))
+ }
+}
+
+private class DummyResponse extends SubmitRestProtocolResponse
+private class DummyRequest extends SubmitRestProtocolRequest {
+ var active: Boolean = null
+ var age: Integer = null
+ var name: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(name, "name")
+ assertFieldIsSet(age, "age")
+ assert(age > 5, "Not old enough!")
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
index 855f1b6276..054a4c6489 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
@@ -29,9 +29,9 @@ class KryoSerializerDistributedSuite extends FunSuite {
test("kryo objects are serialised consistently in different processes") {
val conf = new SparkConf(false)
- conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
- conf.set("spark.task.maxFailures", "1")
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
+ .set("spark.task.maxFailures", "1")
val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
conf.setJars(List(jar.getPath))
diff --git a/pom.xml b/pom.xml
index aef450ae63..da8ee077dd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -154,6 +154,7 @@
<jline.groupid>org.scala-lang</jline.groupid>
<jodd.version>3.6.3</jodd.version>
<codehaus.jackson.version>1.8.8</codehaus.jackson.version>
+ <fasterxml.jackson.version>2.4.4</fasterxml.jackson.version>
<snappy.version>1.1.1.6</snappy.version>
<!--
@@ -578,6 +579,16 @@
<version>${codahale.metrics.version}</version>
</dependency>
<dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>${fasterxml.jackson.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.module</groupId>
+ <artifactId>jackson-module-scala_2.10</artifactId>
+ <version>${fasterxml.jackson.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>
<version>${scala.version}</version>