aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/util/UtilsSuite.scala83
3 files changed, 112 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 9a42487bb3..9c4b8cdc64 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -23,13 +23,12 @@ import scala.collection.JavaConverters._
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
-
-import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.{SecurityManager, SparkConf, Logging}
-import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.{ShutdownHookManager, Utils}
import org.apache.spark.util.logging.FileAppender
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
/**
* Manages the execution of one executor process.
@@ -60,6 +59,9 @@ private[deploy] class ExecutorRunner(
private var stdoutAppender: FileAppender = null
private var stderrAppender: FileAppender = null
+ // Timeout to wait for when trying to terminate an executor.
+ private val EXECUTOR_TERMINATE_TIMEOUT_MS = 10 * 1000
+
// NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might
// make sense to remove this in the future.
private var shutdownHook: AnyRef = null
@@ -94,8 +96,11 @@ private[deploy] class ExecutorRunner(
if (stderrAppender != null) {
stderrAppender.stop()
}
- process.destroy()
- exitCode = Some(process.waitFor())
+ exitCode = Utils.terminateProcess(process, EXECUTOR_TERMINATE_TIMEOUT_MS)
+ if (exitCode.isEmpty) {
+ logWarning("Failed to terminate process: " + process +
+ ". This process will likely be orphaned.")
+ }
}
try {
worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index b8ca6b07e4..0c1f9c1ae2 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1699,6 +1699,30 @@ private[spark] object Utils extends Logging {
}
/**
+ * Terminates a process waiting for at most the specified duration. Returns whether
+ * the process terminated.
+ */
+ def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = {
+ try {
+ // Java8 added a new API which will more forcibly kill the process. Use that if available.
+ val destroyMethod = process.getClass().getMethod("destroyForcibly");
+ destroyMethod.setAccessible(true)
+ destroyMethod.invoke(process)
+ } catch {
+ case NonFatal(e) =>
+ if (!e.isInstanceOf[NoSuchMethodException]) {
+ logWarning("Exception when attempting to kill process", e)
+ }
+ process.destroy()
+ }
+ if (waitForProcess(process, timeoutMs)) {
+ Option(process.exitValue())
+ } else {
+ None
+ }
+ }
+
+ /**
* Wait for a process to terminate for at most the specified duration.
* Return whether the process actually terminated after the given timeout.
*/
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index fdb51d440e..7de995af51 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -17,26 +17,24 @@
package org.apache.spark.util
-import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream}
import java.lang.{Double => JDouble, Float => JFloat}
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
import java.text.DecimalFormatSymbols
-import java.util.concurrent.TimeUnit
import java.util.Locale
+import java.util.concurrent.TimeUnit
import scala.collection.mutable.ListBuffer
import scala.util.Random
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
-
+import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-
import org.apache.spark.network.util.ByteUnit
-import org.apache.spark.{Logging, SparkFunSuite}
-import org.apache.spark.SparkConf
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
@@ -745,4 +743,77 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc")
assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz")
}
+
+ test("Kill process") {
+ // Verify that we can terminate a process even if it is in a bad state. This is only run
+ // on UNIX since it does some OS specific things to verify the correct behavior.
+ if (SystemUtils.IS_OS_UNIX) {
+ def getPid(p: Process): Int = {
+ val f = p.getClass().getDeclaredField("pid")
+ f.setAccessible(true)
+ f.get(p).asInstanceOf[Int]
+ }
+
+ def pidExists(pid: Int): Boolean = {
+ val p = Runtime.getRuntime.exec(s"kill -0 $pid")
+ p.waitFor()
+ p.exitValue() == 0
+ }
+
+ def signal(pid: Int, s: String): Unit = {
+ val p = Runtime.getRuntime.exec(s"kill -$s $pid")
+ p.waitFor()
+ }
+
+ // Start up a process that runs 'sleep 10'. Terminate the process and assert it takes
+ // less time and the process is no longer there.
+ val startTimeMs = System.currentTimeMillis()
+ val process = new ProcessBuilder("sleep", "10").start()
+ val pid = getPid(process)
+ try {
+ assert(pidExists(pid))
+ val terminated = Utils.terminateProcess(process, 5000)
+ assert(terminated.isDefined)
+ Utils.waitForProcess(process, 5000)
+ val durationMs = System.currentTimeMillis() - startTimeMs
+ assert(durationMs < 5000)
+ assert(!pidExists(pid))
+ } finally {
+ // Forcibly kill the test process just in case.
+ signal(pid, "SIGKILL")
+ }
+
+ val v: String = System.getProperty("java.version")
+ if (v >= "1.8.0") {
+ // Java8 added a way to forcibly terminate a process. We'll make sure that works by
+ // creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On
+ // older versions of java, this will *not* terminate.
+ val file = File.createTempFile("temp-file-name", ".tmp")
+ val cmd =
+ s"""
+ |#!/bin/bash
+ |trap "" SIGTERM
+ |sleep 10
+ """.stripMargin
+ Files.write(cmd.getBytes(), file)
+ file.getAbsoluteFile.setExecutable(true)
+
+ val process = new ProcessBuilder(file.getAbsolutePath).start()
+ val pid = getPid(process)
+ assert(pidExists(pid))
+ try {
+ signal(pid, "SIGSTOP")
+ val start = System.currentTimeMillis()
+ val terminated = Utils.terminateProcess(process, 5000)
+ assert(terminated.isDefined)
+ Utils.waitForProcess(process, 5000)
+ val duration = System.currentTimeMillis() - start
+ assert(duration < 5000)
+ assert(!pidExists(pid))
+ } finally {
+ signal(pid, "SIGKILL")
+ }
+ }
+ }
+ }
}