about | summary | refs | log | tree | commit | diff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala | 119
-rw-r--r-- core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala | 73
2 files changed, 142 insertions, 50 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index f4376dedea..289b0b93b0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -32,7 +32,7 @@ import org.apache.spark.deploy.master.DriverState
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.util.{Clock, SystemClock, Utils}
+import org.apache.spark.util.{Clock, ShutdownHookManager, SystemClock, Utils}
/**
* Manages the execution of one driver, including automatically restarting the driver on failure.
@@ -53,9 +53,11 @@ private[deploy] class DriverRunner(
@volatile private var killed = false
// Populated once finished
- private[worker] var finalState: Option[DriverState] = None
- private[worker] var finalException: Option[Exception] = None
- private var finalExitCode: Option[Int] = None
+ @volatile private[worker] var finalState: Option[DriverState] = None
+ @volatile private[worker] var finalException: Option[Exception] = None
+
+ // Timeout to wait for when trying to terminate a driver.
+ private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000
// Decoupled for testing
def setClock(_clock: Clock): Unit = {
@@ -78,49 +80,53 @@ private[deploy] class DriverRunner(
private[worker] def start() = {
new Thread("DriverRunner for " + driverId) {
override def run() {
+ var shutdownHook: AnyRef = null
try {
- val driverDir = createWorkingDirectory()
- val localJarFilename = downloadUserJar(driverDir)
-
- def substituteVariables(argument: String): String = argument match {
- case "{{WORKER_URL}}" => workerUrl
- case "{{USER_JAR}}" => localJarFilename
- case other => other
+ shutdownHook = ShutdownHookManager.addShutdownHook { () =>
+ logInfo(s"Worker shutting down, killing driver $driverId")
+ kill()
}
- // TODO: If we add ability to submit multiple jars they should also be added here
- val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager,
- driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)
- launchDriver(builder, driverDir, driverDesc.supervise)
- }
- catch {
- case e: Exception => finalException = Some(e)
- }
+ // prepare driver jars and run driver
+ val exitCode = prepareAndRunDriver()
- val state =
- if (killed) {
- DriverState.KILLED
- } else if (finalException.isDefined) {
- DriverState.ERROR
+ // set final state depending on if forcibly killed and process exit code
+ finalState = if (exitCode == 0) {
+ Some(DriverState.FINISHED)
+ } else if (killed) {
+ Some(DriverState.KILLED)
} else {
- finalExitCode match {
- case Some(0) => DriverState.FINISHED
- case _ => DriverState.FAILED
- }
+ Some(DriverState.FAILED)
}
+ } catch {
+ case e: Exception =>
+ kill()
+ finalState = Some(DriverState.ERROR)
+ finalException = Some(e)
+ } finally {
+ if (shutdownHook != null) {
+ ShutdownHookManager.removeShutdownHook(shutdownHook)
+ }
+ }
- finalState = Some(state)
-
- worker.send(DriverStateChanged(driverId, state, finalException))
+ // notify worker of final driver state, possible exception
+ worker.send(DriverStateChanged(driverId, finalState.get, finalException))
}
}.start()
}
/** Terminate this driver (or prevent it from ever starting if not yet started) */
- private[worker] def kill() {
+ private[worker] def kill(): Unit = {
+ logInfo("Killing driver process!")
+ killed = true
synchronized {
- process.foreach(_.destroy())
- killed = true
+ process.foreach { p =>
+ val exitCode = Utils.terminateProcess(p, DRIVER_TERMINATE_TIMEOUT_MS)
+ if (exitCode.isEmpty) {
+ logWarning("Failed to terminate driver process: " + p +
+ ". This process will likely be orphaned.")
+ }
+ }
}
}
@@ -142,7 +148,6 @@ private[deploy] class DriverRunner(
*/
private def downloadUserJar(driverDir: File): String = {
val jarPath = new Path(driverDesc.jarUrl)
-
val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
val destPath = new File(driverDir.getAbsolutePath, jarPath.getName)
val jarFileName = jarPath.getName
@@ -168,7 +173,24 @@ private[deploy] class DriverRunner(
localJarFilename
}
- private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) {
+ private[worker] def prepareAndRunDriver(): Int = {
+ val driverDir = createWorkingDirectory()
+ val localJarFilename = downloadUserJar(driverDir)
+
+ def substituteVariables(argument: String): String = argument match {
+ case "{{WORKER_URL}}" => workerUrl
+ case "{{USER_JAR}}" => localJarFilename
+ case other => other
+ }
+
+ // TODO: If we add ability to submit multiple jars they should also be added here
+ val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager,
+ driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)
+
+ runDriver(builder, driverDir, driverDesc.supervise)
+ }
+
+ private def runDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean): Int = {
builder.directory(baseDir)
def initialize(process: Process): Unit = {
// Redirect stdout and stderr to files
@@ -184,39 +206,40 @@ private[deploy] class DriverRunner(
runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
}
- def runCommandWithRetry(
- command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Unit = {
+ private[worker] def runCommandWithRetry(
+ command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Int = {
+ var exitCode = -1
// Time to wait between submission retries.
var waitSeconds = 1
// A run of this many seconds resets the exponential back-off.
val successfulRunDuration = 5
-
var keepTrying = !killed
while (keepTrying) {
logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\""))
synchronized {
- if (killed) { return }
+ if (killed) { return exitCode }
process = Some(command.start())
initialize(process.get)
}
val processStart = clock.getTimeMillis()
- val exitCode = process.get.waitFor()
- if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) {
- waitSeconds = 1
- }
+ exitCode = process.get.waitFor()
- if (supervise && exitCode != 0 && !killed) {
+ // check if attempting another run
+ keepTrying = supervise && exitCode != 0 && !killed
+ if (keepTrying) {
+ if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) {
+ waitSeconds = 1
+ }
logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
sleeper.sleep(waitSeconds)
waitSeconds = waitSeconds * 2 // exponential back-off
}
-
- keepTrying = supervise && exitCode != 0 && !killed
- finalExitCode = Some(exitCode)
}
+
+ exitCode
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
index 2a1696be36..52956045d5 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
@@ -19,13 +19,18 @@ package org.apache.spark.deploy.worker
import java.io.File
+import scala.concurrent.duration._
+
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
+import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.{Command, DriverDescription}
+import org.apache.spark.deploy.master.DriverState
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Clock
class DriverRunnerTest extends SparkFunSuite {
@@ -33,8 +38,10 @@ class DriverRunnerTest extends SparkFunSuite {
val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq())
val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command)
val conf = new SparkConf()
- new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"),
- driverDescription, null, "spark://1.2.3.4/worker/", new SecurityManager(conf))
+ val worker = mock(classOf[RpcEndpointRef])
+ doNothing().when(worker).send(any())
+ spy(new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"),
+ driverDescription, worker, "spark://1.2.3.4/worker/", new SecurityManager(conf)))
}
private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = {
@@ -45,6 +52,19 @@ class DriverRunnerTest extends SparkFunSuite {
(processBuilder, process)
}
+ private def createTestableDriverRunner(
+ processBuilder: ProcessBuilderLike,
+ superviseRetry: Boolean) = {
+ val runner = createDriverRunner()
+ runner.setSleeper(mock(classOf[Sleeper]))
+ doAnswer(new Answer[Int] {
+ def answer(invocation: InvocationOnMock): Int = {
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = superviseRetry)
+ }
+ }).when(runner).prepareAndRunDriver()
+ runner
+ }
+
test("Process succeeds instantly") {
val runner = createDriverRunner()
@@ -145,4 +165,53 @@ class DriverRunnerTest extends SparkFunSuite {
verify(sleeper, times(2)).sleep(2)
}
+ test("Kill process finalized with state KILLED") {
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ val runner = createTestableDriverRunner(processBuilder, superviseRetry = true)
+
+ when(process.waitFor()).thenAnswer(new Answer[Int] {
+ def answer(invocation: InvocationOnMock): Int = {
+ runner.kill()
+ -1
+ }
+ })
+
+ runner.start()
+
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(runner.finalState.get === DriverState.KILLED)
+ }
+ verify(process, times(1)).waitFor()
+ }
+
+ test("Finalized with state FINISHED") {
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ val runner = createTestableDriverRunner(processBuilder, superviseRetry = true)
+ when(process.waitFor()).thenReturn(0)
+ runner.start()
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(runner.finalState.get === DriverState.FINISHED)
+ }
+ }
+
+ test("Finalized with state FAILED") {
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ val runner = createTestableDriverRunner(processBuilder, superviseRetry = false)
+ when(process.waitFor()).thenReturn(-1)
+ runner.start()
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(runner.finalState.get === DriverState.FAILED)
+ }
+ }
+
+ test("Handle exception starting process") {
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ val runner = createTestableDriverRunner(processBuilder, superviseRetry = false)
+ when(processBuilder.start()).thenThrow(new NullPointerException("bad command list"))
+ runner.start()
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(runner.finalState.get === DriverState.ERROR)
+ assert(runner.finalException.get.isInstanceOf[RuntimeException])
+ }
+ }
}