author    Marcelo Vanzin <vanzin@cloudera.com>  2016-08-17 11:12:21 -0700
committer Marcelo Vanzin <vanzin@cloudera.com>  2016-08-17 11:12:21 -0700
commit    e3fec51fa1ed161789ab7aa32ed36efe357b5d31
tree      963eced70149018b74120f532e97484c0d69226e
parent    928ca1c6d12b23d84f9b6205e22d2e756311f072
[SPARK-16930][YARN] Fix a couple of races in cluster app initialization.
There are two narrow races that could cause the ApplicationMaster to miss when
the user application instantiates the SparkContext, which could cause app
failures when nothing was wrong with the app. It was also possible for a
failing application to get stuck in the loop that waits for the context for a
long time, instead of failing quickly.

The change uses a promise to track the SparkContext instance, which gets rid
of the races and allows for some simplification of the code.

Tested with existing unit tests, and a new one being added to test the timeout
code.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #14542 from vanzin/SPARK-16930.
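For context, the heart of the fix is replacing a lock-and-poll loop with a
scala.concurrent.Promise. Below is a minimal, self-contained sketch of that
handoff pattern; it is illustrative only — `Context`, `PromiseHandoffSketch`,
and the method names are stand-ins, not Spark code:

import java.util.concurrent.{TimeUnit, TimeoutException}

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

// Sketch of the promise-based handoff: the user-class thread completes the
// promise, and the AM thread blocks on it with a bounded wait.
object PromiseHandoffSketch {

  case class Context(name: String)

  private val contextPromise = Promise[Context]()

  // User-class thread: called once the context has been created. `success`
  // completes the promise exactly once.
  def contextInitialized(ctx: Context): Unit = contextPromise.success(ctx)

  // User-class thread, on failure: propagate the cause to the waiting thread.
  // `tryFailure` is a no-op if the promise was already completed.
  def userClassFailed(cause: Throwable): Unit = contextPromise.tryFailure(cause)

  // User-class thread, in a finally block: wake the waiter even when the app
  // never created a context. `trySuccess` is likewise a no-op if completed.
  def userClassExited(): Unit = contextPromise.trySuccess(null)

  // AM thread: block for at most waitMs. A timeout surfaces as None so the
  // caller can fail the app, mirroring the patch's EXIT_SC_NOT_INITED path.
  def awaitContext(waitMs: Long): Option[Context] =
    try {
      Option(Await.result(contextPromise.future, Duration(waitMs, TimeUnit.MILLISECONDS)))
    } catch {
      case _: TimeoutException => None
    }
}

Because a Promise can be completed at most once, the success/tryFailure/
trySuccess trio removes the check-then-notify races around the old
AtomicReference; and if tryFailure wins, Await.result rethrows the user
class's failure immediately instead of waiting out the timeout.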
Diffstat (limited to 'yarn')
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala          | 98
-rw-r--r--  yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala |  5
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala           | 22
3 files changed, 66 insertions, 59 deletions
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 614278c8b2..a4b575c85d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -20,9 +20,11 @@ package org.apache.spark.deploy.yarn
import java.io.{File, IOException}
import java.lang.reflect.InvocationTargetException
import java.net.{Socket, URI, URL}
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.{TimeoutException, TimeUnit}
import scala.collection.mutable.HashMap
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -106,12 +108,11 @@ private[spark] class ApplicationMaster(
// Next wait interval before allocator poll.
private var nextAllocationInterval = initialAllocationInterval
- // Fields used in client mode.
private var rpcEnv: RpcEnv = null
private var amEndpoint: RpcEndpointRef = _
- // Fields used in cluster mode.
- private val sparkContextRef = new AtomicReference[SparkContext](null)
+ // In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
+ private val sparkContextPromise = Promise[SparkContext]()
private var credentialRenewer: AMCredentialRenewer = _
@@ -316,23 +317,15 @@ private[spark] class ApplicationMaster(
}
private def sparkContextInitialized(sc: SparkContext) = {
- sparkContextRef.synchronized {
- sparkContextRef.compareAndSet(null, sc)
- sparkContextRef.notifyAll()
- }
- }
-
- private def sparkContextStopped(sc: SparkContext) = {
- sparkContextRef.compareAndSet(sc, null)
+ sparkContextPromise.success(sc)
}
private def registerAM(
+ _sparkConf: SparkConf,
_rpcEnv: RpcEnv,
driverRef: RpcEndpointRef,
uiAddress: String,
securityMgr: SecurityManager) = {
- val sc = sparkContextRef.get()
-
val appId = client.getAttemptId().getApplicationId().toString()
val attemptId = client.getAttemptId().getAttemptId().toString()
val historyAddress =
@@ -341,7 +334,6 @@ private[spark] class ApplicationMaster(
.map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
.getOrElse("")
- val _sparkConf = if (sc != null) sc.getConf else sparkConf
val driverUrl = RpcEndpointAddress(
_sparkConf.get("spark.driver.host"),
_sparkConf.get("spark.driver.port").toInt,
@@ -385,21 +377,35 @@ private[spark] class ApplicationMaster(
// This is a bit hacky, but we need to wait until the spark.driver.port property has
// been set by the Thread executing the user class.
- val sc = waitForSparkContextInitialized()
-
- // If there is no SparkContext at this point, just fail the app.
- if (sc == null) {
- finish(FinalApplicationStatus.FAILED,
- ApplicationMaster.EXIT_SC_NOT_INITED,
- "Timed out waiting for SparkContext.")
- } else {
- rpcEnv = sc.env.rpcEnv
- val driverRef = runAMEndpoint(
- sc.getConf.get("spark.driver.host"),
- sc.getConf.get("spark.driver.port"),
- isClusterMode = true)
- registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
+ logInfo("Waiting for spark context initialization...")
+ val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
+ try {
+ val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
+ Duration(totalWaitTime, TimeUnit.MILLISECONDS))
+ if (sc != null) {
+ rpcEnv = sc.env.rpcEnv
+ val driverRef = runAMEndpoint(
+ sc.getConf.get("spark.driver.host"),
+ sc.getConf.get("spark.driver.port"),
+ isClusterMode = true)
+ registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""),
+ securityMgr)
+ } else {
+ // Sanity check; should never happen in normal operation, since sc should only be null
+ // if the user app did not create a SparkContext.
+ if (!finished) {
+ throw new IllegalStateException("SparkContext is null but app is still running!")
+ }
+ }
userClassThread.join()
+ } catch {
+ case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
+ logError(
+ s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
+ "Please check earlier log output for errors. Failing the application.")
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_SC_NOT_INITED,
+ "Timed out waiting for SparkContext.")
}
}
@@ -409,7 +415,8 @@ private[spark] class ApplicationMaster(
clientMode = true)
val driverRef = waitForSparkDriver()
addAmIpFilter()
- registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
+ registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""),
+ securityMgr)
// In client mode the actor will stop the reporter thread.
reporterThread.join()
@@ -525,26 +532,6 @@ private[spark] class ApplicationMaster(
}
}
- private def waitForSparkContextInitialized(): SparkContext = {
- logInfo("Waiting for spark context initialization")
- sparkContextRef.synchronized {
- val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
- val deadline = System.currentTimeMillis() + totalWaitTime
-
- while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
- logInfo("Waiting for spark context initialization ... ")
- sparkContextRef.wait(10000L)
- }
-
- val sparkContext = sparkContextRef.get()
- if (sparkContext == null) {
- logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
- + " log output for errors. Failing the application.").format(totalWaitTime))
- }
- sparkContext
- }
- }
-
private def waitForSparkDriver(): RpcEndpointRef = {
logInfo("Waiting for Spark driver to be reachable.")
var driverUp = false
@@ -647,6 +634,13 @@ private[spark] class ApplicationMaster(
ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
"User class threw exception: " + cause)
}
+ sparkContextPromise.tryFailure(e.getCause())
+ } finally {
+ // Notify the thread waiting for the SparkContext, in case the application did not
+ // instantiate one. This will do nothing when the user code instantiates a SparkContext
+ // (with the correct master), or when the user code throws an exception (due to the
+ // tryFailure above).
+ sparkContextPromise.trySuccess(null)
}
}
}
@@ -759,10 +753,6 @@ object ApplicationMaster extends Logging {
master.sparkContextInitialized(sc)
}
- private[spark] def sparkContextStopped(sc: SparkContext): Boolean = {
- master.sparkContextStopped(sc)
- }
-
private[spark] def getAttemptId(): ApplicationAttemptId = {
master.getAttemptId
}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 72ec4d6b34..96c9151fc3 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -34,9 +34,4 @@ private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnSchedule
logInfo("YarnClusterScheduler.postStartHook done")
}
- override def stop() {
- super.stop()
- ApplicationMaster.sparkContextStopped(sc)
- }
-
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 8ab7b21c22..fb7926f6a1 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.launcher._
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
@@ -192,6 +193,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
}
}
+ test("timeout to get SparkContext in cluster mode triggers failure") {
+ val timeout = 2000
+ val finalState = runSpark(false, mainClassName(SparkContextTimeoutApp.getClass),
+ appArgs = Seq((timeout * 4).toString),
+ extraConf = Map(AM_MAX_WAIT_TIME.key -> timeout.toString))
+ finalState should be (SparkAppHandle.State.FAILED)
+ }
+
private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
@@ -469,3 +478,16 @@ private object YarnLauncherTestApp {
}
}
+
+/**
+ * Used to test code in the AM that detects the SparkContext instance. Expects a single argument
+ * with the duration to sleep for, in ms.
+ */
+private object SparkContextTimeoutApp {
+
+ def main(args: Array[String]): Unit = {
+ val Array(sleepTime) = args
+ Thread.sleep(java.lang.Long.parseLong(sleepTime))
+ }
+
+}