Diffstat (limited to 'core/src/main/scala/org/apache/spark/util/Utils.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 136
1 file changed, 118 insertions(+), 18 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1029b0f9fc..7b0de1ae55 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,7 +21,7 @@ import java.io._
import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
+import java.util.{PriorityQueue, Properties, Locale, Random, UUID}
import java.util.concurrent._
import javax.net.ssl.HttpsURLConnection
@@ -30,7 +30,7 @@ import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.reflect.ClassTag
-import scala.util.Try
+import scala.util.{Failure, Success, Try}
import scala.util.control.{ControlThrowable, NonFatal}
import com.google.common.io.{ByteStreams, Files}
@@ -64,9 +64,15 @@ private[spark] object CallSite {
private[spark] object Utils extends Logging {
val random = new Random()
+ val DEFAULT_SHUTDOWN_PRIORITY = 100
+
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
@volatile private var localRootDirs: Array[String] = null
+
+ private val shutdownHooks = new SparkShutdownHookManager()
+ shutdownHooks.install()
+
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
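The first hunk wires the new shutdown machinery into object Utils: shutdownHooks is a plain val in the object body, so the manager is created and install() runs exactly once, when Utils is first referenced (Scala object initialization is lazy and happens under the JVM's class-initialization lock). A minimal sketch of that pattern, with a hypothetical object name:

    // Sketch only: one-time registration piggybacking on Scala's lazy,
    // thread-safe object initialization, as Utils does with shutdownHooks.
    object CleanupManager {
      // This statement runs exactly once, on first reference to CleanupManager.
      Runtime.getRuntime.addShutdownHook(new Thread(new Runnable {
        override def run(): Unit = println("cleanup at JVM exit")
      }))
    }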
@@ -176,18 +182,16 @@ private[spark] object Utils extends Logging {
private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]()
// Add a shutdown hook to delete the temp dirs when the JVM exits
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") {
- override def run(): Unit = Utils.logUncaughtExceptions {
- logDebug("Shutdown hook called")
- shutdownDeletePaths.foreach { dirPath =>
- try {
- Utils.deleteRecursively(new File(dirPath))
- } catch {
- case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e)
- }
+ addShutdownHook { () =>
+ logDebug("Shutdown hook called")
+ shutdownDeletePaths.foreach { dirPath =>
+ try {
+ Utils.deleteRecursively(new File(dirPath))
+ } catch {
+ case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e)
}
}
- })
+ }
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
@@ -613,7 +617,7 @@ private[spark] object Utils extends Logging {
}
Utils.setupSecureURLConnection(uc, securityMgr)
- val timeoutMs =
+ val timeoutMs =
conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000
uc.setConnectTimeout(timeoutMs)
uc.setReadTimeout(timeoutMs)
@@ -1172,7 +1176,7 @@ private[spark] object Utils extends Logging {
/**
* Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the
* default UncaughtExceptionHandler
- *
+ *
* NOTE: This method is to be called by the spark-started JVM process.
*/
def tryOrExit(block: => Unit) {
@@ -1185,11 +1189,11 @@ private[spark] object Utils extends Logging {
}
/**
- * Execute a block of code that evaluates to Unit, stop SparkContext is there is any uncaught
+ * Execute a block of code that evaluates to Unit, stopping the SparkContext if there is any uncaught
* exception
- *
- * NOTE: This method is to be called by the driver-side components to avoid stopping the
- * user-started JVM process completely; in contrast, tryOrExit is to be called in the
+ *
+ * NOTE: This method is to be called by the driver-side components to avoid stopping the
+ * user-started JVM process completely; in contrast, tryOrExit is to be called in the
* spark-started JVM process.
*/
def tryOrStopSparkContext(sc: SparkContext)(block: => Unit) {
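tryOrStopSparkContext is the driver-side counterpart to tryOrExit: rather than escalating to the JVM's uncaught-exception handler (which would take down the user's process), it stops the given SparkContext and leaves the JVM running. A hedged usage sketch, where sc and pollNextEvent are assumptions, not part of the patch:

    // Sketch: a driver-side background thread whose failures stop the
    // SparkContext instead of the user's whole JVM.
    val listenerThread = new Thread("event-poll") {
      override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
        while (!Thread.currentThread().isInterrupted) {
          pollNextEvent() // hypothetical unit of work
        }
      }
    }
    listenerThread.setDaemon(true)
    listenerThread.start()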
@@ -2132,6 +2136,102 @@ private[spark] object Utils extends Logging {
.getOrElse(UserGroupInformation.getCurrentUser().getShortUserName())
}
+ /**
+ * Adds a shutdown hook with the default priority (DEFAULT_SHUTDOWN_PRIORITY).
+ *
+ * @param hook The code to run during shutdown.
+ * @return A handle that can be used to unregister the shutdown hook.
+ */
+ def addShutdownHook(hook: () => Unit): AnyRef = {
+ addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook)
+ }
+
+ /**
+ * Adds a shutdown hook with the given priority. Hooks with higher priority values run
+ * first.
+ *
+ * @param hook The code to run during shutdown.
+ * @return A handle that can be used to unregister the shutdown hook.
+ */
+ def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = {
+ shutdownHooks.add(priority, hook)
+ }
+
+ /**
+ * Remove a previously installed shutdown hook.
+ *
+ * @param ref A handle returned by `addShutdownHook`.
+ * @return Whether the hook was removed.
+ */
+ def removeShutdownHook(ref: AnyRef): Boolean = {
+ shutdownHooks.remove(ref)
+ }
+
+}
+
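These three methods are the whole public surface: register a hook (with an explicit priority or the default of 100), keep the opaque handle, and pass the handle back to unregister. Since hooks with higher priority values run first, low values suit last-step cleanup. A usage sketch:

    // Sketch: typical lifecycle of hooks registered through the new API.
    val stopServices = Utils.addShutdownHook { () =>
      println("stopping services") // default priority 100, runs first
    }
    val wipeTempDirs = Utils.addShutdownHook(25, () =>
      println("deleting temp dirs")) // priority 25, runs after the 100s

    // A hook can be dropped again if its cleanup already happened:
    val removed: Boolean = Utils.removeShutdownHook(wipeTempDirs)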
+private[util] class SparkShutdownHookManager {
+
+ private val hooks = new PriorityQueue[SparkShutdownHook]()
+ private var shuttingDown = false
+
+ /**
+ * Install a single hook that runs all registered hooks in priority order at shutdown.
+ * Hadoop 1.x does not have `ShutdownHookManager`, so in that case we just use the JVM's
+ * `Runtime` object and hope for the best.
+ */
+ def install(): Unit = {
+ val hookTask = new Runnable() {
+ override def run(): Unit = runAll()
+ }
+ Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match {
+ case Success(shmClass) =>
+ val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY")
+   .get(null).asInstanceOf[Int] // static field: pass null as the receiver
+ val shm = shmClass.getMethod("get").invoke(null)
+ shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int])
+ .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30))
+
+ case Failure(_) =>
+ Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook"))
+ }
+ }
+
+ def runAll(): Unit = synchronized {
+ shuttingDown = true
+ while (!hooks.isEmpty()) {
+ Utils.logUncaughtExceptions(hooks.poll().run())
+ }
+ }
+
+ def add(priority: Int, hook: () => Unit): AnyRef = synchronized {
+ checkState()
+ val hookRef = new SparkShutdownHook(priority, hook)
+ hooks.add(hookRef)
+ hookRef
+ }
+
+ def remove(ref: AnyRef): Boolean = synchronized {
+ checkState()
+ hooks.remove(ref)
+ }
+
+ private def checkState(): Unit = {
+ if (shuttingDown) {
+ throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.")
+ }
+ }
+
+}
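install() registers exactly one native hook on behalf of all Spark hooks. When Hadoop's ShutdownHookManager is on the classpath (Hadoop 2.x), that hook goes in at FileSystem.SHUTDOWN_HOOK_PRIORITY + 30; in Hadoop's manager, higher priorities run earlier, so Spark's hooks fire before the FileSystem caches close and can still touch HDFS. Compiled directly against Hadoop 2.x, the registration would read as follows (a sketch that sidesteps the reflection the patch needs for Hadoop 1.x compatibility):

    // Sketch: direct-call equivalent of the reflective Success(...) branch,
    // valid only when Hadoop 2.x classes are available at compile time.
    import org.apache.hadoop.fs.FileSystem
    import org.apache.hadoop.util.ShutdownHookManager

    ShutdownHookManager.get().addShutdownHook(
      new Runnable { override def run(): Unit = println("run Spark hooks") },
      FileSystem.SHUTDOWN_HOOK_PRIORITY + 30)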
+
+private class SparkShutdownHook(private val priority: Int, hook: () => Unit)
+ extends Comparable[SparkShutdownHook] {
+
+ override def compareTo(other: SparkShutdownHook): Int = {
+ other.priority.compareTo(priority) // reversed: higher-priority hooks sort (and run) first
+ }
+
+ def run(): Unit = hook()
+
}
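java.util.PriorityQueue is a min-heap: poll() returns the element that compares smallest. The comparison above is therefore reversed (other before this), so a hook with a larger priority value compares as smaller and is polled, and run, first. A quick illustration (a sketch; SparkShutdownHook is file-private, so this only compiles inside Utils.scala):

    // Sketch: the reversed comparison makes the min-heap yield high
    // priorities first.
    val q = new PriorityQueue[SparkShutdownHook]()
    q.add(new SparkShutdownHook(25, () => println("temp dirs")))
    q.add(new SparkShutdownHook(100, () => println("services")))
    q.poll().run() // prints "services"  -- priority 100 runs first
    q.poll().run() // prints "temp dirs" -- priority 25 runs last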
/**