author     Bryan Cutler <cutlerb@gmail.com>        2016-08-11 14:49:11 -0700
committer  Marcelo Vanzin <vanzin@cloudera.com>    2016-08-11 14:49:11 -0700
commit     1c9a386c6b6812a3931f3fb0004249894a01f657 (patch)
tree       5548fcaa17b7e6a72846c25600d6df21fb95b040 /core
parent     cf9367826c38e5f34ae69b409f5d09c55ed1d319 (diff)
[SPARK-13602][CORE] Add shutdown hook to DriverRunner to prevent driver process leak
## What changes were proposed in this pull request?

Added a shutdown hook to `DriverRunner` to kill the driver process in case the Worker JVM exits suddenly and the `WorkerWatcher` is unable to catch this. Also did some cleanup to consolidate driver state management and the setting of the finalized vars within the running thread.

## How was this patch tested?

Added unit tests to verify that the final state and exception variables are set accordingly for successful runs, failed runs, and errors in the driver process. Retrofitted an existing test to verify that killing a mocked process ends with the correct state and stops properly.

Manually tested (with deploy-mode=cluster) that the shutdown hook is called by forcibly exiting the `Worker` at various points in the code, with the `WorkerWatcher` both disabled and enabled. Also manually killed the driver through the UI and verified that the `DriverRunner` interrupted, killed the process, and exited properly.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #11746 from BryanCutler/DriverRunner-shutdown-hook-SPARK-13602.
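The core of the change is a register/kill/unregister bracket around the driver process. A minimal standalone sketch of that pattern, using plain JVM shutdown hooks since Spark's `ShutdownHookManager` is internal to Spark; `runChildProcess` and the 10-second grace period are illustrative, not part of Spark's API:

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference

object ShutdownHookSketch {
  def runChildProcess(builder: ProcessBuilder): Int = {
    val child = new AtomicReference[Process]()
    val hook = new Thread(new Runnable {
      // Runs only if this JVM exits before the child process does.
      override def run(): Unit = Option(child.get()).foreach { p =>
        p.destroy()
        if (!p.waitFor(10, TimeUnit.SECONDS)) p.destroyForcibly()
      }
    })
    Runtime.getRuntime.addShutdownHook(hook)
    try {
      val p = builder.start()
      child.set(p)
      p.waitFor()
    } finally {
      // Mirrors the removeShutdownHook call in DriverRunner's finally block.
      // Note: throws IllegalStateException if shutdown is already in progress,
      // which a production version would guard against.
      Runtime.getRuntime.removeShutdownHook(hook)
    }
  }
}
```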
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala      119
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala   73
2 files changed, 142 insertions(+), 50 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index f4376dedea..289b0b93b0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -32,7 +32,7 @@ import org.apache.spark.deploy.master.DriverState
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.util.{Clock, SystemClock, Utils}
+import org.apache.spark.util.{Clock, ShutdownHookManager, SystemClock, Utils}
/**
* Manages the execution of one driver, including automatically restarting the driver on failure.
@@ -53,9 +53,11 @@ private[deploy] class DriverRunner(
@volatile private var killed = false
// Populated once finished
- private[worker] var finalState: Option[DriverState] = None
- private[worker] var finalException: Option[Exception] = None
- private var finalExitCode: Option[Int] = None
+ @volatile private[worker] var finalState: Option[DriverState] = None
+ @volatile private[worker] var finalException: Option[Exception] = None
+
+ // Timeout to wait for when trying to terminate a driver.
+ private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000
// Decoupled for testing
def setClock(_clock: Clock): Unit = {
@@ -78,49 +80,53 @@ private[deploy] class DriverRunner(
private[worker] def start() = {
new Thread("DriverRunner for " + driverId) {
override def run() {
+ var shutdownHook: AnyRef = null
try {
- val driverDir = createWorkingDirectory()
- val localJarFilename = downloadUserJar(driverDir)
-
- def substituteVariables(argument: String): String = argument match {
- case "{{WORKER_URL}}" => workerUrl
- case "{{USER_JAR}}" => localJarFilename
- case other => other
+ shutdownHook = ShutdownHookManager.addShutdownHook { () =>
+ logInfo(s"Worker shutting down, killing driver $driverId")
+ kill()
}
- // TODO: If we add ability to submit multiple jars they should also be added here
- val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager,
- driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)
- launchDriver(builder, driverDir, driverDesc.supervise)
- }
- catch {
- case e: Exception => finalException = Some(e)
- }
+ // prepare driver jars and run driver
+ val exitCode = prepareAndRunDriver()
- val state =
- if (killed) {
- DriverState.KILLED
- } else if (finalException.isDefined) {
- DriverState.ERROR
+ // set final state depending on if forcibly killed and process exit code
+ finalState = if (exitCode == 0) {
+ Some(DriverState.FINISHED)
+ } else if (killed) {
+ Some(DriverState.KILLED)
} else {
- finalExitCode match {
- case Some(0) => DriverState.FINISHED
- case _ => DriverState.FAILED
- }
+ Some(DriverState.FAILED)
}
+ } catch {
+ case e: Exception =>
+ kill()
+ finalState = Some(DriverState.ERROR)
+ finalException = Some(e)
+ } finally {
+ if (shutdownHook != null) {
+ ShutdownHookManager.removeShutdownHook(shutdownHook)
+ }
+ }
- finalState = Some(state)
-
- worker.send(DriverStateChanged(driverId, state, finalException))
+ // notify worker of final driver state, possible exception
+ worker.send(DriverStateChanged(driverId, finalState.get, finalException))
}
}.start()
}
/** Terminate this driver (or prevent it from ever starting if not yet started) */
- private[worker] def kill() {
+ private[worker] def kill(): Unit = {
+ logInfo("Killing driver process!")
+ killed = true
synchronized {
- process.foreach(_.destroy())
- killed = true
+ process.foreach { p =>
+ val exitCode = Utils.terminateProcess(p, DRIVER_TERMINATE_TIMEOUT_MS)
+ if (exitCode.isEmpty) {
+ logWarning("Failed to terminate driver process: " + p +
+ ". This process will likely be orphaned.")
+ }
+ }
}
}
@@ -142,7 +148,6 @@ private[deploy] class DriverRunner(
*/
private def downloadUserJar(driverDir: File): String = {
val jarPath = new Path(driverDesc.jarUrl)
-
val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
val destPath = new File(driverDir.getAbsolutePath, jarPath.getName)
val jarFileName = jarPath.getName
@@ -168,7 +173,24 @@ private[deploy] class DriverRunner(
localJarFilename
}
- private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) {
+ private[worker] def prepareAndRunDriver(): Int = {
+ val driverDir = createWorkingDirectory()
+ val localJarFilename = downloadUserJar(driverDir)
+
+ def substituteVariables(argument: String): String = argument match {
+ case "{{WORKER_URL}}" => workerUrl
+ case "{{USER_JAR}}" => localJarFilename
+ case other => other
+ }
+
+ // TODO: If we add ability to submit multiple jars they should also be added here
+ val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager,
+ driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)
+
+ runDriver(builder, driverDir, driverDesc.supervise)
+ }
+
+ private def runDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean): Int = {
builder.directory(baseDir)
def initialize(process: Process): Unit = {
// Redirect stdout and stderr to files
@@ -184,39 +206,40 @@ private[deploy] class DriverRunner(
runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
}
- def runCommandWithRetry(
- command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Unit = {
+ private[worker] def runCommandWithRetry(
+ command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Int = {
+ var exitCode = -1
// Time to wait between submission retries.
var waitSeconds = 1
// A run of this many seconds resets the exponential back-off.
val successfulRunDuration = 5
-
var keepTrying = !killed
while (keepTrying) {
logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\""))
synchronized {
- if (killed) { return }
+ if (killed) { return exitCode }
process = Some(command.start())
initialize(process.get)
}
val processStart = clock.getTimeMillis()
- val exitCode = process.get.waitFor()
- if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) {
- waitSeconds = 1
- }
+ exitCode = process.get.waitFor()
- if (supervise && exitCode != 0 && !killed) {
+ // check if attempting another run
+ keepTrying = supervise && exitCode != 0 && !killed
+ if (keepTrying) {
+ if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) {
+ waitSeconds = 1
+ }
logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
sleeper.sleep(waitSeconds)
waitSeconds = waitSeconds * 2 // exponential back-off
}
-
- keepTrying = supervise && exitCode != 0 && !killed
- finalExitCode = Some(exitCode)
}
+
+ exitCode
}
}
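The new `kill()` above delegates to `Utils.terminateProcess` with a 10-second timeout and logs a warning when it returns `None`. A simplified sketch of that graceful-then-forcible contract, not Spark's actual implementation; the object and method names are illustrative:

```scala
import java.util.concurrent.TimeUnit

object TerminateSketch {
  // Ask the process to exit, escalate after the timeout, and report None
  // when it survives both attempts (the exitCode.isEmpty case in kill()).
  def terminateWithTimeout(p: Process, timeoutMs: Long): Option[Int] = {
    p.destroy() // polite request first (SIGTERM on Unix)
    if (p.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) {
      Some(p.exitValue())
    } else {
      p.destroyForcibly() // escalate (SIGKILL on Unix)
      if (p.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) Some(p.exitValue()) else None
    }
  }
}
```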
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
index 2a1696be36..52956045d5 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
@@ -19,13 +19,18 @@ package org.apache.spark.deploy.worker
import java.io.File
+import scala.concurrent.duration._
+
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
+import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.{Command, DriverDescription}
+import org.apache.spark.deploy.master.DriverState
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Clock
class DriverRunnerTest extends SparkFunSuite {
@@ -33,8 +38,10 @@ class DriverRunnerTest extends SparkFunSuite {
val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq())
val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command)
val conf = new SparkConf()
- new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"),
- driverDescription, null, "spark://1.2.3.4/worker/", new SecurityManager(conf))
+ val worker = mock(classOf[RpcEndpointRef])
+ doNothing().when(worker).send(any())
+ spy(new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"),
+ driverDescription, worker, "spark://1.2.3.4/worker/", new SecurityManager(conf)))
}
private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = {
@@ -45,6 +52,19 @@ class DriverRunnerTest extends SparkFunSuite {
(processBuilder, process)
}
+ private def createTestableDriverRunner(
+ processBuilder: ProcessBuilderLike,
+ superviseRetry: Boolean) = {
+ val runner = createDriverRunner()
+ runner.setSleeper(mock(classOf[Sleeper]))
+ doAnswer(new Answer[Int] {
+ def answer(invocation: InvocationOnMock): Int = {
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = superviseRetry)
+ }
+ }).when(runner).prepareAndRunDriver()
+ runner
+ }
+
test("Process succeeds instantly") {
val runner = createDriverRunner()
@@ -145,4 +165,53 @@ class DriverRunnerTest extends SparkFunSuite {
verify(sleeper, times(2)).sleep(2)
}
+ test("Kill process finalized with state KILLED") {
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ val runner = createTestableDriverRunner(processBuilder, superviseRetry = true)
+
+ when(process.waitFor()).thenAnswer(new Answer[Int] {
+ def answer(invocation: InvocationOnMock): Int = {
+ runner.kill()
+ -1
+ }
+ })
+
+ runner.start()
+
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(runner.finalState.get === DriverState.KILLED)
+ }
+ verify(process, times(1)).waitFor()
+ }
+
+ test("Finalized with state FINISHED") {
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ val runner = createTestableDriverRunner(processBuilder, superviseRetry = true)
+ when(process.waitFor()).thenReturn(0)
+ runner.start()
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(runner.finalState.get === DriverState.FINISHED)
+ }
+ }
+
+ test("Finalized with state FAILED") {
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ val runner = createTestableDriverRunner(processBuilder, superviseRetry = false)
+ when(process.waitFor()).thenReturn(-1)
+ runner.start()
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(runner.finalState.get === DriverState.FAILED)
+ }
+ }
+
+ test("Handle exception starting process") {
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ val runner = createTestableDriverRunner(processBuilder, superviseRetry = false)
+ when(processBuilder.start()).thenThrow(new NullPointerException("bad command list"))
+ runner.start()
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(runner.finalState.get === DriverState.ERROR)
+ assert(runner.finalException.get.isInstanceOf[RuntimeException])
+ }
+ }
}
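The new tests rely on ScalaTest's `eventually`, which retries the assertion block until it passes or the timeout elapses, avoiding a race with the runner's background thread. A distilled sketch of that pattern using the era-appropriate `org.scalatest.FunSuite`; the class and field names are illustrative:

```scala
import java.util.concurrent.atomic.AtomicReference

import scala.concurrent.duration._

import org.scalatest.FunSuite
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}

class EventuallySketch extends FunSuite {
  test("a background thread eventually publishes its final state") {
    val state = new AtomicReference[Option[String]](None)
    // Simulates DriverRunner's thread setting finalState asynchronously.
    new Thread(new Runnable {
      override def run(): Unit = state.set(Some("FINISHED"))
    }).start()

    // Polls every 100 ms, failing only if 10 s pass without success.
    eventually(timeout(10.seconds), interval(100.millis)) {
      assert(state.get() === Some("FINISHED"))
    }
  }
}
```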