aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Or <andrew@databricks.com>2015-02-12 14:47:52 -0800
committerAndrew Or <andrew@databricks.com>2015-02-12 14:47:52 -0800
commit1d5663e92cdaaa3dabfa58fdd7aede7e4fa4ec63 (patch)
tree0f737e6b9f718ef0391df50cdf0ea5d315fc71d7
parent47c73d410ab533c3196184d2b6004081e79daeaa (diff)
downloadspark-1d5663e92cdaaa3dabfa58fdd7aede7e4fa4ec63.tar.gz
spark-1d5663e92cdaaa3dabfa58fdd7aede7e4fa4ec63.tar.bz2
spark-1d5663e92cdaaa3dabfa58fdd7aede7e4fa4ec63.zip
[SPARK-5760][SPARK-5761] Fix standalone rest protocol corner cases + revamp tests
The changes are summarized in the commit message. Test or test-related code accounts for 90% of the lines changed. Author: Andrew Or <andrew@databricks.com> Closes #4557 from andrewor14/rest-tests and squashes the following commits: b4dc980 [Andrew Or] Merge branch 'master' of github.com:apache/spark into rest-tests b55e40f [Andrew Or] Add test for unknown fields cc96993 [Andrew Or] private[spark] -> private[rest] 578cf45 [Andrew Or] Clean up test code a little d82d971 [Andrew Or] v1 -> serverVersion ea48f65 [Andrew Or] Merge branch 'master' of github.com:apache/spark into rest-tests 00999a8 [Andrew Or] Revamp tests + fix a few corner cases
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala52
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala105
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala671
3 files changed, 589 insertions, 239 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
index 115aa5278b..c4be1f19e8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
@@ -19,10 +19,11 @@ package org.apache.spark.deploy.rest
import java.io.{DataOutputStream, FileNotFoundException}
import java.net.{HttpURLConnection, SocketException, URL}
+import javax.servlet.http.HttpServletResponse
import scala.io.Source
-import com.fasterxml.jackson.databind.JsonMappingException
+import com.fasterxml.jackson.core.JsonProcessingException
import com.google.common.base.Charsets
import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
@@ -155,10 +156,21 @@ private[spark] class StandaloneRestClient extends Logging {
/**
* Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]].
* If the response represents an error, report the embedded message to the user.
+ * Exposed for testing.
*/
- private def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
+ private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
try {
- val responseJson = Source.fromInputStream(connection.getInputStream).mkString
+ val dataStream =
+ if (connection.getResponseCode == HttpServletResponse.SC_OK) {
+ connection.getInputStream
+ } else {
+ connection.getErrorStream
+ }
+ // If the server threw an exception while writing a response, it will not have a body
+ if (dataStream == null) {
+ throw new SubmitRestProtocolException("Server returned empty body")
+ }
+ val responseJson = Source.fromInputStream(dataStream).mkString
logDebug(s"Response from the server:\n$responseJson")
val response = SubmitRestProtocolMessage.fromJson(responseJson)
response.validate()
@@ -177,7 +189,7 @@ private[spark] class StandaloneRestClient extends Logging {
case unreachable @ (_: FileNotFoundException | _: SocketException) =>
throw new SubmitRestConnectionException(
s"Unable to connect to server ${connection.getURL}", unreachable)
- case malformed @ (_: SubmitRestProtocolException | _: JsonMappingException) =>
+ case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
throw new SubmitRestProtocolException(
"Malformed response received from server", malformed)
}
@@ -284,7 +296,27 @@ private[spark] object StandaloneRestClient {
val REPORT_DRIVER_STATUS_MAX_TRIES = 10
val PROTOCOL_VERSION = "v1"
- /** Submit an application, assuming Spark parameters are specified through system properties. */
+ /**
+ * Submit an application, assuming Spark parameters are specified through the given config.
+ * This is abstracted to its own method for testing purposes.
+ */
+ private[rest] def run(
+ appResource: String,
+ mainClass: String,
+ appArgs: Array[String],
+ conf: SparkConf,
+ env: Map[String, String] = sys.env): SubmitRestProtocolResponse = {
+ val master = conf.getOption("spark.master").getOrElse {
+ throw new IllegalArgumentException("'spark.master' must be set.")
+ }
+ val sparkProperties = conf.getAll.toMap
+ val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") }
+ val client = new StandaloneRestClient
+ val submitRequest = client.constructSubmitRequest(
+ appResource, mainClass, appArgs, sparkProperties, environmentVariables)
+ client.createSubmission(master, submitRequest)
+ }
+
def main(args: Array[String]): Unit = {
if (args.size < 2) {
sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]")
@@ -294,14 +326,6 @@ private[spark] object StandaloneRestClient {
val mainClass = args(1)
val appArgs = args.slice(2, args.size)
val conf = new SparkConf
- val master = conf.getOption("spark.master").getOrElse {
- throw new IllegalArgumentException("'spark.master' must be set.")
- }
- val sparkProperties = conf.getAll.toMap
- val environmentVariables = sys.env.filter { case (k, _) => k.startsWith("SPARK_") }
- val client = new StandaloneRestClient
- val submitRequest = client.constructSubmitRequest(
- appResource, mainClass, appArgs, sparkProperties, environmentVariables)
- client.createSubmission(master, submitRequest)
+ run(appResource, mainClass, appArgs, conf)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index acd3a2b5ab..f9e0478e4f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -17,15 +17,14 @@
package org.apache.spark.deploy.rest
-import java.io.{DataOutputStream, File}
+import java.io.File
import java.net.InetSocketAddress
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import scala.io.Source
import akka.actor.ActorRef
-import com.fasterxml.jackson.databind.JsonMappingException
-import com.google.common.base.Charsets
+import com.fasterxml.jackson.core.JsonProcessingException
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
@@ -70,14 +69,14 @@ private[spark] class StandaloneRestServer(
import StandaloneRestServer._
private var _server: Option[Server] = None
- private val baseContext = s"/$PROTOCOL_VERSION/submissions"
-
- // A mapping from servlets to the URL prefixes they are responsible for
- private val servletToContext = Map[StandaloneRestServlet, String](
- new SubmitRequestServlet(masterActor, masterUrl, masterConf) -> s"$baseContext/create/*",
- new KillRequestServlet(masterActor, masterConf) -> s"$baseContext/kill/*",
- new StatusRequestServlet(masterActor, masterConf) -> s"$baseContext/status/*",
- new ErrorServlet -> "/*" // default handler
+
+ // A mapping from URL prefixes to servlets that serve them. Exposed for testing.
+ protected val baseContext = s"/$PROTOCOL_VERSION/submissions"
+ protected val contextToServlet = Map[String, StandaloneRestServlet](
+ s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf),
+ s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf),
+ s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf),
+ "/*" -> new ErrorServlet // default handler
)
/** Start the server and return the bound port. */
@@ -99,7 +98,7 @@ private[spark] class StandaloneRestServer(
server.setThreadPool(threadPool)
val mainHandler = new ServletContextHandler
mainHandler.setContextPath("/")
- servletToContext.foreach { case (servlet, prefix) =>
+ contextToServlet.foreach { case (prefix, servlet) =>
mainHandler.addServlet(new ServletHolder(servlet), prefix)
}
server.setHandler(mainHandler)
@@ -113,7 +112,7 @@ private[spark] class StandaloneRestServer(
}
}
-private object StandaloneRestServer {
+private[rest] object StandaloneRestServer {
val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION
val SC_UNKNOWN_PROTOCOL_VERSION = 468
}
@@ -121,20 +120,7 @@ private object StandaloneRestServer {
/**
* An abstract servlet for handling requests passed to the [[StandaloneRestServer]].
*/
-private abstract class StandaloneRestServlet extends HttpServlet with Logging {
-
- /** Service a request. If an exception is thrown in the process, indicate server error. */
- protected override def service(
- request: HttpServletRequest,
- response: HttpServletResponse): Unit = {
- try {
- super.service(request, response)
- } catch {
- case e: Exception =>
- logError("Exception while handling request", e)
- response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
- }
- }
+private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging {
/**
* Serialize the given response message to JSON and send it through the response servlet.
@@ -146,11 +132,7 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
val message = validateResponse(responseMessage, responseServlet)
responseServlet.setContentType("application/json")
responseServlet.setCharacterEncoding("utf-8")
- responseServlet.setStatus(HttpServletResponse.SC_OK)
- val content = message.toJson.getBytes(Charsets.UTF_8)
- val out = new DataOutputStream(responseServlet.getOutputStream)
- out.write(content)
- out.close()
+ responseServlet.getWriter.write(message.toJson)
}
/**
@@ -187,6 +169,19 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
}
/**
+ * Parse a submission ID from the relative path, assuming it is the first part of the path.
+ * For instance, we expect the path to take the form /[submission ID]/maybe/something/else.
+ * The returned submission ID cannot be empty. If the path is unexpected, return None.
+ */
+ protected def parseSubmissionId(path: String): Option[String] = {
+ if (path == null || path.isEmpty) {
+ None
+ } else {
+ path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty)
+ }
+ }
+
+ /**
* Validate the response to ensure that it is correctly constructed.
*
* If it is, simply return the message as is. Otherwise, return an error response instead
@@ -209,7 +204,7 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
/**
* A servlet for handling kill requests passed to the [[StandaloneRestServer]].
*/
-private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
+private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
extends StandaloneRestServlet {
/**
@@ -219,18 +214,15 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
protected override def doPost(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
- val submissionId = request.getPathInfo.stripPrefix("/")
- val responseMessage =
- if (submissionId.nonEmpty) {
- handleKill(submissionId)
- } else {
- response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
- handleError("Submission ID is missing in kill request.")
- }
+ val submissionId = parseSubmissionId(request.getPathInfo)
+ val responseMessage = submissionId.map(handleKill).getOrElse {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError("Submission ID is missing in kill request.")
+ }
sendResponse(responseMessage, response)
}
- private def handleKill(submissionId: String): KillSubmissionResponse = {
+ protected def handleKill(submissionId: String): KillSubmissionResponse = {
val askTimeout = AkkaUtils.askTimeout(conf)
val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
@@ -246,7 +238,7 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
/**
* A servlet for handling status requests passed to the [[StandaloneRestServer]].
*/
-private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
+private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
extends StandaloneRestServlet {
/**
@@ -256,18 +248,15 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
protected override def doGet(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
- val submissionId = request.getPathInfo.stripPrefix("/")
- val responseMessage =
- if (submissionId.nonEmpty) {
- handleStatus(submissionId)
- } else {
- response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
- handleError("Submission ID is missing in status request.")
- }
+ val submissionId = parseSubmissionId(request.getPathInfo)
+ val responseMessage = submissionId.map(handleStatus).getOrElse {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError("Submission ID is missing in status request.")
+ }
sendResponse(responseMessage, response)
}
- private def handleStatus(submissionId: String): SubmissionStatusResponse = {
+ protected def handleStatus(submissionId: String): SubmissionStatusResponse = {
val askTimeout = AkkaUtils.askTimeout(conf)
val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
@@ -287,7 +276,7 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
/**
* A servlet for handling submit requests passed to the [[StandaloneRestServer]].
*/
-private class SubmitRequestServlet(
+private[rest] class SubmitRequestServlet(
masterActor: ActorRef,
masterUrl: String,
conf: SparkConf)
@@ -313,7 +302,7 @@ private class SubmitRequestServlet(
handleSubmit(requestMessageJson, requestMessage, responseServlet)
} catch {
// The client failed to provide a valid JSON, so this is not our fault
- case e @ (_: JsonMappingException | _: SubmitRestProtocolException) =>
+ case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Malformed request: " + formatException(e))
}
@@ -413,7 +402,7 @@ private class ErrorServlet extends StandaloneRestServlet {
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
val path = request.getPathInfo
- val parts = path.stripPrefix("/").split("/").toSeq
+ val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList
var versionMismatch = false
var msg =
parts match {
@@ -423,10 +412,10 @@ private class ErrorServlet extends StandaloneRestServlet {
case `serverVersion` :: Nil =>
// http://host:port/correct-version
"Missing the /submissions prefix."
- case `serverVersion` :: "submissions" :: Nil =>
- // http://host:port/correct-version/submissions
+ case `serverVersion` :: "submissions" :: tail =>
+ // http://host:port/correct-version/submissions/*
"Missing an action: please specify one of /create, /kill, or /status."
- case unknownVersion :: _ =>
+ case unknownVersion :: tail =>
// http://host:port/unknown-version/*
versionMismatch = true
s"Unknown protocol version '$unknownVersion'."
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index 29aed89b67..a345e06ecb 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -17,141 +17,412 @@
package org.apache.spark.deploy.rest
-import java.io.{File, FileInputStream, FileOutputStream, PrintWriter}
-import java.util.jar.{JarEntry, JarOutputStream}
-import java.util.zip.ZipEntry
+import java.io.DataOutputStream
+import java.net.{HttpURLConnection, URL}
+import javax.servlet.http.HttpServletResponse
-import scala.collection.mutable.ArrayBuffer
-import scala.io.Source
+import scala.collection.mutable
-import akka.actor.ActorSystem
-import com.google.common.io.ByteStreams
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
-import org.scalatest.exceptions.TestFailedException
+import akka.actor.{Actor, ActorRef, ActorSystem, Props}
+import com.google.common.base.Charsets
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+import org.json4s.JsonAST._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
-import org.apache.spark.deploy.master.{DriverState, Master}
-import org.apache.spark.deploy.worker.Worker
+import org.apache.spark.deploy.master.DriverState._
/**
- * End-to-end tests for the REST application submission protocol in standalone mode.
+ * Tests for the REST application submission protocol used in standalone cluster mode.
*/
-class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
- private val systemsToStop = new ArrayBuffer[ActorSystem]
- private val masterRestUrl = startLocalCluster()
+class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach {
private val client = new StandaloneRestClient
- private val mainJar = StandaloneRestSubmitSuite.createJar()
- private val mainClass = StandaloneRestApp.getClass.getName.stripSuffix("$")
+ private var actorSystem: Option[ActorSystem] = None
+ private var server: Option[StandaloneRestServer] = None
- override def afterAll() {
- systemsToStop.foreach(_.shutdown())
+ override def afterEach() {
+ actorSystem.foreach(_.shutdown())
+ server.foreach(_.stop())
}
- test("simple submit until completion") {
- val resultsFile = File.createTempFile("test-submit", ".txt")
- val numbers = Seq(1, 2, 3)
- val size = 500
- val submissionId = submitApplication(resultsFile, numbers, size)
- waitUntilFinished(submissionId)
- validateResult(resultsFile, numbers, size)
+ test("construct submit request") {
+ val appArgs = Array("one", "two", "three")
+ val sparkProperties = Map("spark.app.name" -> "pi")
+ val environmentVariables = Map("SPARK_ONE" -> "UN", "SPARK_TWO" -> "DEUX")
+ val request = client.constructSubmitRequest(
+ "my-app-resource", "my-main-class", appArgs, sparkProperties, environmentVariables)
+ assert(request.action === Utils.getFormattedClassName(request))
+ assert(request.clientSparkVersion === SPARK_VERSION)
+ assert(request.appResource === "my-app-resource")
+ assert(request.mainClass === "my-main-class")
+ assert(request.appArgs === appArgs)
+ assert(request.sparkProperties === sparkProperties)
+ assert(request.environmentVariables === environmentVariables)
}
- test("kill empty submission") {
- val response = client.killSubmission(masterRestUrl, "submission-that-does-not-exist")
- val killResponse = getKillResponse(response)
- val killSuccess = killResponse.success
- assert(!killSuccess)
+ test("create submission") {
+ val submittedDriverId = "my-driver-id"
+ val submitMessage = "your driver is submitted"
+ val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage)
+ val appArgs = Array("one", "two", "four")
+ val request = constructSubmitRequest(masterUrl, appArgs)
+ assert(request.appArgs === appArgs)
+ assert(request.sparkProperties("spark.master") === masterUrl)
+ val response = client.createSubmission(masterUrl, request)
+ val submitResponse = getSubmitResponse(response)
+ assert(submitResponse.action === Utils.getFormattedClassName(submitResponse))
+ assert(submitResponse.serverSparkVersion === SPARK_VERSION)
+ assert(submitResponse.message === submitMessage)
+ assert(submitResponse.submissionId === submittedDriverId)
+ assert(submitResponse.success)
+ }
+
+ test("create submission from main method") {
+ val submittedDriverId = "your-driver-id"
+ val submitMessage = "my driver is submitted"
+ val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage)
+ val conf = new SparkConf(loadDefaults = false)
+ conf.set("spark.master", masterUrl)
+ conf.set("spark.app.name", "dreamer")
+ val appArgs = Array("one", "two", "six")
+ // main method calls this
+ val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf)
+ val submitResponse = getSubmitResponse(response)
+ assert(submitResponse.action === Utils.getFormattedClassName(submitResponse))
+ assert(submitResponse.serverSparkVersion === SPARK_VERSION)
+ assert(submitResponse.message === submitMessage)
+ assert(submitResponse.submissionId === submittedDriverId)
+ assert(submitResponse.success)
}
- test("kill running submission") {
- val resultsFile = File.createTempFile("test-kill", ".txt")
- val numbers = Seq(1, 2, 3)
- val size = 500
- val submissionId = submitApplication(resultsFile, numbers, size)
- val response = client.killSubmission(masterRestUrl, submissionId)
+ test("kill submission") {
+ val submissionId = "my-lyft-driver"
+ val killMessage = "your driver is killed"
+ val masterUrl = startDummyServer(killMessage = killMessage)
+ val response = client.killSubmission(masterUrl, submissionId)
val killResponse = getKillResponse(response)
- val killSuccess = killResponse.success
- waitUntilFinished(submissionId)
- val response2 = client.requestSubmissionStatus(masterRestUrl, submissionId)
- val statusResponse = getStatusResponse(response2)
- val statusSuccess = statusResponse.success
- val driverState = statusResponse.driverState
- assert(killSuccess)
- assert(statusSuccess)
- assert(driverState === DriverState.KILLED.toString)
- // we should not see the expected results because we killed the submission
- intercept[TestFailedException] { validateResult(resultsFile, numbers, size) }
+ assert(killResponse.action === Utils.getFormattedClassName(killResponse))
+ assert(killResponse.serverSparkVersion === SPARK_VERSION)
+ assert(killResponse.message === killMessage)
+ assert(killResponse.submissionId === submissionId)
+ assert(killResponse.success)
}
- test("request status for empty submission") {
- val response = client.requestSubmissionStatus(masterRestUrl, "submission-that-does-not-exist")
+ test("request submission status") {
+ val submissionId = "my-uber-driver"
+ val submissionState = KILLED
+ val submissionException = new Exception("there was an irresponsible mix of alcohol and cars")
+ val masterUrl = startDummyServer(state = submissionState, exception = Some(submissionException))
+ val response = client.requestSubmissionStatus(masterUrl, submissionId)
val statusResponse = getStatusResponse(response)
- val statusSuccess = statusResponse.success
- assert(!statusSuccess)
+ assert(statusResponse.action === Utils.getFormattedClassName(statusResponse))
+ assert(statusResponse.serverSparkVersion === SPARK_VERSION)
+ assert(statusResponse.message.contains(submissionException.getMessage))
+ assert(statusResponse.submissionId === submissionId)
+ assert(statusResponse.driverState === submissionState.toString)
+ assert(statusResponse.success)
+ }
+
+ test("create then kill") {
+ val masterUrl = startSmartServer()
+ val request = constructSubmitRequest(masterUrl)
+ val response1 = client.createSubmission(masterUrl, request)
+ val submitResponse = getSubmitResponse(response1)
+ assert(submitResponse.success)
+ assert(submitResponse.submissionId != null)
+ // kill submission that was just created
+ val submissionId = submitResponse.submissionId
+ val response2 = client.killSubmission(masterUrl, submissionId)
+ val killResponse = getKillResponse(response2)
+ assert(killResponse.success)
+ assert(killResponse.submissionId === submissionId)
+ }
+
+ test("create then request status") {
+ val masterUrl = startSmartServer()
+ val request = constructSubmitRequest(masterUrl)
+ val response1 = client.createSubmission(masterUrl, request)
+ val submitResponse = getSubmitResponse(response1)
+ assert(submitResponse.success)
+ assert(submitResponse.submissionId != null)
+ // request status of submission that was just created
+ val submissionId = submitResponse.submissionId
+ val response2 = client.requestSubmissionStatus(masterUrl, submissionId)
+ val statusResponse = getStatusResponse(response2)
+ assert(statusResponse.success)
+ assert(statusResponse.submissionId === submissionId)
+ assert(statusResponse.driverState === RUNNING.toString)
+ }
+
+ test("create then kill then request status") {
+ val masterUrl = startSmartServer()
+ val request = constructSubmitRequest(masterUrl)
+ val response1 = client.createSubmission(masterUrl, request)
+ val response2 = client.createSubmission(masterUrl, request)
+ val submitResponse1 = getSubmitResponse(response1)
+ val submitResponse2 = getSubmitResponse(response2)
+ assert(submitResponse1.success)
+ assert(submitResponse2.success)
+ assert(submitResponse1.submissionId != null)
+ assert(submitResponse2.submissionId != null)
+ val submissionId1 = submitResponse1.submissionId
+ val submissionId2 = submitResponse2.submissionId
+ // kill only submission 1, but not submission 2
+ val response3 = client.killSubmission(masterUrl, submissionId1)
+ val killResponse = getKillResponse(response3)
+ assert(killResponse.success)
+ assert(killResponse.submissionId === submissionId1)
+ // request status for both submissions: 1 should be KILLED but 2 should be RUNNING still
+ val response4 = client.requestSubmissionStatus(masterUrl, submissionId1)
+ val response5 = client.requestSubmissionStatus(masterUrl, submissionId2)
+ val statusResponse1 = getStatusResponse(response4)
+ val statusResponse2 = getStatusResponse(response5)
+ assert(statusResponse1.submissionId === submissionId1)
+ assert(statusResponse2.submissionId === submissionId2)
+ assert(statusResponse1.driverState === KILLED.toString)
+ assert(statusResponse2.driverState === RUNNING.toString)
+ }
+
+ test("kill or request status before create") {
+ val masterUrl = startSmartServer()
+ val doesNotExist = "does-not-exist"
+ // kill a non-existent submission
+ val response1 = client.killSubmission(masterUrl, doesNotExist)
+ val killResponse = getKillResponse(response1)
+ assert(!killResponse.success)
+ assert(killResponse.submissionId === doesNotExist)
+ // request status for a non-existent submission
+ val response2 = client.requestSubmissionStatus(masterUrl, doesNotExist)
+ val statusResponse = getStatusResponse(response2)
+ assert(!statusResponse.success)
+ assert(statusResponse.submissionId === doesNotExist)
+ }
+
+ /* ---------------------------------------- *
+ | Aberrant client / server behavior |
+ * ---------------------------------------- */
+
+ test("good request paths") {
+ val masterUrl = startSmartServer()
+ val httpUrl = masterUrl.replace("spark://", "http://")
+ val v = StandaloneRestServer.PROTOCOL_VERSION
+ val json = constructSubmitRequest(masterUrl).toJson
+ val submitRequestPath = s"$httpUrl/$v/submissions/create"
+ val killRequestPath = s"$httpUrl/$v/submissions/kill"
+ val statusRequestPath = s"$httpUrl/$v/submissions/status"
+ val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", json)
+ val (response2, code2) = sendHttpRequestWithResponse(s"$killRequestPath/anything", "POST")
+ val (response3, code3) = sendHttpRequestWithResponse(s"$killRequestPath/any/thing", "POST")
+ val (response4, code4) = sendHttpRequestWithResponse(s"$statusRequestPath/anything", "GET")
+ val (response5, code5) = sendHttpRequestWithResponse(s"$statusRequestPath/any/thing", "GET")
+ // these should all succeed and the responses should be of the correct types
+ getSubmitResponse(response1)
+ val killResponse1 = getKillResponse(response2)
+ val killResponse2 = getKillResponse(response3)
+ val statusResponse1 = getStatusResponse(response4)
+ val statusResponse2 = getStatusResponse(response5)
+ assert(killResponse1.submissionId === "anything")
+ assert(killResponse2.submissionId === "any")
+ assert(statusResponse1.submissionId === "anything")
+ assert(statusResponse2.submissionId === "any")
+ assert(code1 === HttpServletResponse.SC_OK)
+ assert(code2 === HttpServletResponse.SC_OK)
+ assert(code3 === HttpServletResponse.SC_OK)
+ assert(code4 === HttpServletResponse.SC_OK)
+ assert(code5 === HttpServletResponse.SC_OK)
+ }
+
+ test("good request paths, bad requests") {
+ val masterUrl = startSmartServer()
+ val httpUrl = masterUrl.replace("spark://", "http://")
+ val v = StandaloneRestServer.PROTOCOL_VERSION
+ val submitRequestPath = s"$httpUrl/$v/submissions/create"
+ val killRequestPath = s"$httpUrl/$v/submissions/kill"
+ val statusRequestPath = s"$httpUrl/$v/submissions/status"
+ val goodJson = constructSubmitRequest(masterUrl).toJson
+ val badJson1 = goodJson.replaceAll("action", "fraction") // invalid JSON
+ val badJson2 = goodJson.substring(goodJson.size / 2) // malformed JSON
+ val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST") // missing JSON
+ val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson1)
+ val (response3, code3) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson2)
+ val (response4, code4) = sendHttpRequestWithResponse(killRequestPath, "POST") // missing ID
+ val (response5, code5) = sendHttpRequestWithResponse(s"$killRequestPath/", "POST")
+ val (response6, code6) = sendHttpRequestWithResponse(statusRequestPath, "GET") // missing ID
+ val (response7, code7) = sendHttpRequestWithResponse(s"$statusRequestPath/", "GET")
+ // these should all fail as error responses
+ getErrorResponse(response1)
+ getErrorResponse(response2)
+ getErrorResponse(response3)
+ getErrorResponse(response4)
+ getErrorResponse(response5)
+ getErrorResponse(response6)
+ getErrorResponse(response7)
+ assert(code1 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code2 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code3 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code4 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code5 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code6 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code7 === HttpServletResponse.SC_BAD_REQUEST)
+ }
+
+ test("bad request paths") {
+ val masterUrl = startSmartServer()
+ val httpUrl = masterUrl.replace("spark://", "http://")
+ val v = StandaloneRestServer.PROTOCOL_VERSION
+ val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET")
+ val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET")
+ val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET")
+ val (response4, code4) = sendHttpRequestWithResponse(s"$httpUrl/$v/", "GET")
+ val (response5, code5) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions", "GET")
+ val (response6, code6) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/", "GET")
+ val (response7, code7) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/bad", "GET")
+ val (response8, code8) = sendHttpRequestWithResponse(s"$httpUrl/bad-version", "GET")
+ assert(code1 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code2 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code3 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code4 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code5 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code6 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code7 === HttpServletResponse.SC_BAD_REQUEST)
+ assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION)
+ // all responses should be error responses
+ val errorResponse1 = getErrorResponse(response1)
+ val errorResponse2 = getErrorResponse(response2)
+ val errorResponse3 = getErrorResponse(response3)
+ val errorResponse4 = getErrorResponse(response4)
+ val errorResponse5 = getErrorResponse(response5)
+ val errorResponse6 = getErrorResponse(response6)
+ val errorResponse7 = getErrorResponse(response7)
+ val errorResponse8 = getErrorResponse(response8)
+ // only the incompatible version response should have server protocol version set
+ assert(errorResponse1.highestProtocolVersion === null)
+ assert(errorResponse2.highestProtocolVersion === null)
+ assert(errorResponse3.highestProtocolVersion === null)
+ assert(errorResponse4.highestProtocolVersion === null)
+ assert(errorResponse5.highestProtocolVersion === null)
+ assert(errorResponse6.highestProtocolVersion === null)
+ assert(errorResponse7.highestProtocolVersion === null)
+ assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION)
+ }
+
+ test("server returns unknown fields") {
+ val masterUrl = startSmartServer()
+ val httpUrl = masterUrl.replace("spark://", "http://")
+ val v = StandaloneRestServer.PROTOCOL_VERSION
+ val submitRequestPath = s"$httpUrl/$v/submissions/create"
+ val oldJson = constructSubmitRequest(masterUrl).toJson
+ val oldFields = parse(oldJson).asInstanceOf[JObject].obj
+ val newFields = oldFields ++ Seq(
+ JField("tomato", JString("not-a-fruit")),
+ JField("potato", JString("not-po-tah-to"))
+ )
+ val newJson = pretty(render(JObject(newFields)))
+ // send two requests, one with the unknown fields and the other without
+ val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", oldJson)
+ val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", newJson)
+ val submitResponse1 = getSubmitResponse(response1)
+ val submitResponse2 = getSubmitResponse(response2)
+ assert(code1 === HttpServletResponse.SC_OK)
+ assert(code2 === HttpServletResponse.SC_OK)
+ // only the response to the modified request should have unknown fields set
+ assert(submitResponse1.unknownFields === null)
+ assert(submitResponse2.unknownFields === Array("tomato", "potato"))
+ }
+
+ test("client handles faulty server") {
+ val masterUrl = startFaultyServer()
+ val httpUrl = masterUrl.replace("spark://", "http://")
+ val v = StandaloneRestServer.PROTOCOL_VERSION
+ val submitRequestPath = s"$httpUrl/$v/submissions/create"
+ val killRequestPath = s"$httpUrl/$v/submissions/kill/anything"
+ val statusRequestPath = s"$httpUrl/$v/submissions/status/anything"
+ val json = constructSubmitRequest(masterUrl).toJson
+ // server returns malformed response unwittingly
+ // client should throw an appropriate exception to indicate server failure
+ val conn1 = sendHttpRequest(submitRequestPath, "POST", json)
+ intercept[SubmitRestProtocolException] { client.readResponse(conn1) }
+ // server attempts to send invalid response, but fails internally on validation
+ // client should receive an error response as server is able to recover
+ val conn2 = sendHttpRequest(killRequestPath, "POST")
+ val response2 = client.readResponse(conn2)
+ getErrorResponse(response2)
+ assert(conn2.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+ // server explodes internally beyond recovery
+ // client should throw an appropriate exception to indicate server failure
+ val conn3 = sendHttpRequest(statusRequestPath, "GET")
+ intercept[SubmitRestProtocolException] { client.readResponse(conn3) } // empty response
+ assert(conn3.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+ }
+
+ /* --------------------- *
+ | Helper methods |
+ * --------------------- */
+
+ /** Start a dummy server that responds to requests using the specified parameters. */
+ private def startDummyServer(
+ submitId: String = "fake-driver-id",
+ submitMessage: String = "driver is submitted",
+ killMessage: String = "driver is killed",
+ state: DriverState = FINISHED,
+ exception: Option[Exception] = None): String = {
+ startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception))
+ }
+
+ /** Start a smarter dummy server that keeps track of submitted driver states. */
+ private def startSmartServer(): String = {
+ startServer(new SmarterMaster)
+ }
+
+ /** Start a dummy server that is faulty in many ways... */
+ private def startFaultyServer(): String = {
+ startServer(new DummyMaster, faulty = true)
}
/**
- * Start a local cluster containing one Master and a few Workers.
- * Do not use [[org.apache.spark.deploy.LocalSparkCluster]] here because we want the REST URL.
- * Return the Master's REST URL to which applications should be submitted.
+ * Start a [[StandaloneRestServer]] that communicates with the given actor.
+ * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead.
+ * Return the master URL that corresponds to the address of this server.
*/
- private def startLocalCluster(): String = {
- val conf = new SparkConf(false)
- .set("spark.master.rest.enabled", "true")
- .set("spark.master.rest.port", "0")
- val (numWorkers, coresPerWorker, memPerWorker) = (2, 1, 512)
- val localHostName = Utils.localHostName()
- val (masterSystem, masterPort, _, _masterRestPort) =
- Master.startSystemAndActor(localHostName, 0, 0, conf)
- val masterRestPort = _masterRestPort.getOrElse { fail("REST server not started on Master!") }
- val masterUrl = "spark://" + localHostName + ":" + masterPort
- val masterRestUrl = "spark://" + localHostName + ":" + masterRestPort
- (1 to numWorkers).foreach { n =>
- val (workerSystem, _) = Worker.startSystemAndActor(
- localHostName, 0, 0, coresPerWorker, memPerWorker, Array(masterUrl), null, Some(n))
- systemsToStop.append(workerSystem)
- }
- systemsToStop.append(masterSystem)
- masterRestUrl
+ private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = {
+ val name = "test-standalone-rest-protocol"
+ val conf = new SparkConf
+ val localhost = Utils.localHostName()
+ val securityManager = new SecurityManager(conf)
+ val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager)
+ val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster))
+ val _server =
+ if (faulty) {
+ new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf)
+ } else {
+ new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf)
+ }
+ val port = _server.start()
+ // set these to clean them up after every test
+ actorSystem = Some(_actorSystem)
+ server = Some(_server)
+ s"spark://$localhost:$port"
}
- /** Submit the [[StandaloneRestApp]] and return the corresponding submission ID. */
- private def submitApplication(resultsFile: File, numbers: Seq[Int], size: Int): String = {
- val appArgs = Seq(resultsFile.getAbsolutePath) ++ numbers.map(_.toString) ++ Seq(size.toString)
+ /** Create a submit request with real parameters using Spark submit. */
+ private def constructSubmitRequest(
+ masterUrl: String,
+ appArgs: Array[String] = Array.empty): CreateSubmissionRequest = {
+ val mainClass = "main-class-not-used"
+ val mainJar = "dummy-jar-not-used.jar"
val commandLineArgs = Array(
"--deploy-mode", "cluster",
- "--master", masterRestUrl,
+ "--master", masterUrl,
"--name", mainClass,
"--class", mainClass,
mainJar) ++ appArgs
val args = new SparkSubmitArguments(commandLineArgs)
val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args)
- val request = client.constructSubmitRequest(
- mainJar, mainClass, appArgs.toArray, sparkProperties.toMap, Map.empty)
- val response = client.createSubmission(masterRestUrl, request)
- val submitResponse = getSubmitResponse(response)
- val submissionId = submitResponse.submissionId
- assert(submissionId != null, "Application submission was unsuccessful!")
- submissionId
- }
-
- /** Wait until the given submission has finished running up to the specified timeout. */
- private def waitUntilFinished(submissionId: String, maxSeconds: Int = 30): Unit = {
- var finished = false
- val expireTime = System.currentTimeMillis + maxSeconds * 1000
- while (!finished) {
- val response = client.requestSubmissionStatus(masterRestUrl, submissionId)
- val statusResponse = getStatusResponse(response)
- val driverState = statusResponse.driverState
- finished =
- driverState != DriverState.SUBMITTED.toString &&
- driverState != DriverState.RUNNING.toString
- if (System.currentTimeMillis > expireTime) {
- fail(s"Driver $submissionId did not finish within $maxSeconds seconds.")
- }
- }
+ client.constructSubmitRequest(
+ mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty)
}
/** Return the response as a submit response, or fail with error otherwise. */
@@ -181,85 +452,151 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with Bef
}
}
- /** Validate whether the application produced the corrupt output. */
- private def validateResult(resultsFile: File, numbers: Seq[Int], size: Int): Unit = {
- val lines = Source.fromFile(resultsFile.getAbsolutePath).getLines().toSeq
- val unexpectedContent =
- if (lines.nonEmpty) {
- "[\n" + lines.map { l => " " + l }.mkString("\n") + "\n]"
- } else {
- "[EMPTY]"
- }
- assert(lines.size === 2, s"Unexpected content in file: $unexpectedContent")
- assert(lines(0).toInt === numbers.sum, s"Sum of ${numbers.mkString(",")} is incorrect")
- assert(lines(1).toInt === (size / 2) + 1, "Result of Spark job is incorrect")
+ /** Return the response as an error response, or fail if the response was not an error. */
+ private def getErrorResponse(response: SubmitRestProtocolResponse): ErrorResponse = {
+ response match {
+ case e: ErrorResponse => e
+ case r => fail(s"Expected error response. Actual: ${r.toJson}")
+ }
}
-}
-
-private object StandaloneRestSubmitSuite {
- private val pathPrefix = this.getClass.getPackage.getName.replaceAll("\\.", "/")
/**
- * Create a jar that contains all the class files needed for running the [[StandaloneRestApp]].
- * Return the absolute path to that jar.
+ * Send an HTTP request to the given URL using the method and the body specified.
+ * Return the connection object.
*/
- def createJar(): String = {
- val jarFile = File.createTempFile("test-standalone-rest-protocol", ".jar")
- val jarFileStream = new FileOutputStream(jarFile)
- val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest)
- jarStream.putNextEntry(new ZipEntry(pathPrefix))
- getClassFiles.foreach { cf =>
- jarStream.putNextEntry(new JarEntry(pathPrefix + "/" + cf.getName))
- val in = new FileInputStream(cf)
- ByteStreams.copy(in, jarStream)
- in.close()
+ private def sendHttpRequest(
+ url: String,
+ method: String,
+ body: String = ""): HttpURLConnection = {
+ val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
+ conn.setRequestMethod(method)
+ if (body.nonEmpty) {
+ conn.setDoOutput(true)
+ val out = new DataOutputStream(conn.getOutputStream)
+ out.write(body.getBytes(Charsets.UTF_8))
+ out.close()
}
- jarStream.close()
- jarFileStream.close()
- jarFile.getAbsolutePath
+ conn
}
/**
- * Return a list of class files compiled for [[StandaloneRestApp]].
- * This includes all the anonymous classes used in the application.
+ * Send an HTTP request to the given URL using the method and the body specified.
+ * Return a 2-tuple of the response message from the server and the response code.
*/
- private def getClassFiles: Seq[File] = {
- val className = Utils.getFormattedClassName(StandaloneRestApp)
- val clazz = StandaloneRestApp.getClass
- val basePath = clazz.getProtectionDomain.getCodeSource.getLocation.toURI.getPath
- val baseDir = new File(basePath + "/" + pathPrefix)
- baseDir.listFiles().filter(_.getName.contains(className))
+ private def sendHttpRequestWithResponse(
+ url: String,
+ method: String,
+ body: String = ""): (SubmitRestProtocolResponse, Int) = {
+ val conn = sendHttpRequest(url, method, body)
+ (client.readResponse(conn), conn.getResponseCode)
}
}
/**
- * Sample application to be submitted to the cluster using the REST gateway.
- * All relevant classes will be packaged into a jar at run time.
+ * A mock standalone Master that responds with dummy messages.
+ * In all responses, the success parameter is always true.
*/
-object StandaloneRestApp {
- // Usage: [path to results file] [num1] [num2] [num3] [rddSize]
- // The first line of the results file should be (num1 + num2 + num3)
- // The second line should be (rddSize / 2) + 1
- def main(args: Array[String]) {
- assert(args.size == 5, s"Expected exactly 5 arguments: ${args.mkString(",")}")
- val resultFile = new File(args(0))
- val writer = new PrintWriter(resultFile)
- try {
- val conf = new SparkConf()
- val sc = new SparkContext(conf)
- val firstLine = args(1).toInt + args(2).toInt + args(3).toInt
- val secondLine = sc.parallelize(1 to args(4).toInt)
- .map { i => (i / 2, i) }
- .reduceByKey(_ + _)
- .count()
- writer.println(firstLine)
- writer.println(secondLine)
- } catch {
- case e: Exception =>
- writer.println(e)
- e.getStackTrace.foreach { l => writer.println(" " + l) }
- } finally {
- writer.close()
+private class DummyMaster(
+ submitId: String = "fake-driver-id",
+ submitMessage: String = "submitted",
+ killMessage: String = "killed",
+ state: DriverState = FINISHED,
+ exception: Option[Exception] = None)
+ extends Actor {
+
+ override def receive = {
+ case RequestSubmitDriver(driverDesc) =>
+ sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage)
+ case RequestKillDriver(driverId) =>
+ sender ! KillDriverResponse(driverId, success = true, killMessage)
+ case RequestDriverStatus(driverId) =>
+ sender ! DriverStatusResponse(found = true, Some(state), None, None, exception)
+ }
+}
+
+/**
+ * A mock standalone Master that keeps track of drivers that have been submitted.
+ *
+ * If a driver is submitted, its state is immediately set to RUNNING.
+ * If an existing driver is killed, its state is immediately set to KILLED.
+ * If an existing driver's status is requested, its state is returned in the response.
+ * Submits are always successful while kills and status requests are successful only
+ * if the driver was submitted in the past.
+ */
+private class SmarterMaster extends Actor {
+ private var counter: Int = 0
+ private val submittedDrivers = new mutable.HashMap[String, DriverState]
+
+ override def receive = {
+ case RequestSubmitDriver(driverDesc) =>
+ val driverId = s"driver-$counter"
+ submittedDrivers(driverId) = RUNNING
+ counter += 1
+ sender ! SubmitDriverResponse(success = true, Some(driverId), "submitted")
+
+ case RequestKillDriver(driverId) =>
+ val success = submittedDrivers.contains(driverId)
+ if (success) {
+ submittedDrivers(driverId) = KILLED
+ }
+ sender ! KillDriverResponse(driverId, success, "killed")
+
+ case RequestDriverStatus(driverId) =>
+ val found = submittedDrivers.contains(driverId)
+ val state = submittedDrivers.get(driverId)
+ sender ! DriverStatusResponse(found, state, None, None, None)
+ }
+}
+
+/**
+ * A [[StandaloneRestServer]] that is faulty in many ways.
+ *
+ * When handling a submit request, the server returns a malformed JSON.
+ * When handling a kill request, the server returns an invalid JSON.
+ * When handling a status request, the server throws an internal exception.
+ * The purpose of this class is to test that client handles these cases gracefully.
+ */
+private class FaultyStandaloneRestServer(
+ host: String,
+ requestedPort: Int,
+ masterActor: ActorRef,
+ masterUrl: String,
+ masterConf: SparkConf)
+ extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) {
+
+ protected override val contextToServlet = Map[String, StandaloneRestServlet](
+ s"$baseContext/create/*" -> new MalformedSubmitServlet,
+ s"$baseContext/kill/*" -> new InvalidKillServlet,
+ s"$baseContext/status/*" -> new ExplodingStatusServlet,
+ "/*" -> new ErrorServlet
+ )
+
+ /** A faulty servlet that produces malformed responses. */
+ class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) {
+ protected override def sendResponse(
+ responseMessage: SubmitRestProtocolResponse,
+ responseServlet: HttpServletResponse): Unit = {
+ val badJson = responseMessage.toJson.drop(10).dropRight(20)
+ responseServlet.getWriter.write(badJson)
+ }
+ }
+
+ /** A faulty servlet that produces invalid responses. */
+ class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) {
+ protected override def handleKill(submissionId: String): KillSubmissionResponse = {
+ val k = super.handleKill(submissionId)
+ k.submissionId = null
+ k
+ }
+ }
+
+ /** A faulty status servlet that explodes. */
+ class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) {
+ private def explode: Int = 1 / 0
+ protected override def handleStatus(submissionId: String): SubmissionStatusResponse = {
+ val s = super.handleStatus(submissionId)
+ s.workerId = explode.toString
+ s
}
}
}