about | summary | refs | log | tree | commit | diff
path: root/yarn
diff options
context:
space:
mode:
Diffstat (limited to 'yarn')
-rw-r--r-- yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 98
-rw-r--r-- yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala | 5
-rw-r--r-- yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 22
3 files changed, 66 insertions, 59 deletions
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 614278c8b2..a4b575c85d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -20,9 +20,11 @@ package org.apache.spark.deploy.yarn
import java.io.{File, IOException}
import java.lang.reflect.InvocationTargetException
import java.net.{Socket, URI, URL}
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.{TimeoutException, TimeUnit}
import scala.collection.mutable.HashMap
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -106,12 +108,11 @@ private[spark] class ApplicationMaster(
// Next wait interval before allocator poll.
private var nextAllocationInterval = initialAllocationInterval
- // Fields used in client mode.
private var rpcEnv: RpcEnv = null
private var amEndpoint: RpcEndpointRef = _
- // Fields used in cluster mode.
- private val sparkContextRef = new AtomicReference[SparkContext](null)
+ // In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
+ private val sparkContextPromise = Promise[SparkContext]()
private var credentialRenewer: AMCredentialRenewer = _
@@ -316,23 +317,15 @@ private[spark] class ApplicationMaster(
}
private def sparkContextInitialized(sc: SparkContext) = {
- sparkContextRef.synchronized {
- sparkContextRef.compareAndSet(null, sc)
- sparkContextRef.notifyAll()
- }
- }
-
- private def sparkContextStopped(sc: SparkContext) = {
- sparkContextRef.compareAndSet(sc, null)
+ sparkContextPromise.success(sc)
}
private def registerAM(
+ _sparkConf: SparkConf,
_rpcEnv: RpcEnv,
driverRef: RpcEndpointRef,
uiAddress: String,
securityMgr: SecurityManager) = {
- val sc = sparkContextRef.get()
-
val appId = client.getAttemptId().getApplicationId().toString()
val attemptId = client.getAttemptId().getAttemptId().toString()
val historyAddress =
@@ -341,7 +334,6 @@ private[spark] class ApplicationMaster(
.map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
.getOrElse("")
- val _sparkConf = if (sc != null) sc.getConf else sparkConf
val driverUrl = RpcEndpointAddress(
_sparkConf.get("spark.driver.host"),
_sparkConf.get("spark.driver.port").toInt,
@@ -385,21 +377,35 @@ private[spark] class ApplicationMaster(
// This a bit hacky, but we need to wait until the spark.driver.port property has
// been set by the Thread executing the user class.
- val sc = waitForSparkContextInitialized()
-
- // If there is no SparkContext at this point, just fail the app.
- if (sc == null) {
- finish(FinalApplicationStatus.FAILED,
- ApplicationMaster.EXIT_SC_NOT_INITED,
- "Timed out waiting for SparkContext.")
- } else {
- rpcEnv = sc.env.rpcEnv
- val driverRef = runAMEndpoint(
- sc.getConf.get("spark.driver.host"),
- sc.getConf.get("spark.driver.port"),
- isClusterMode = true)
- registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
+ logInfo("Waiting for spark context initialization...")
+ val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
+ try {
+ val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
+ Duration(totalWaitTime, TimeUnit.MILLISECONDS))
+ if (sc != null) {
+ rpcEnv = sc.env.rpcEnv
+ val driverRef = runAMEndpoint(
+ sc.getConf.get("spark.driver.host"),
+ sc.getConf.get("spark.driver.port"),
+ isClusterMode = true)
+ registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""),
+ securityMgr)
+ } else {
+ // Sanity check; should never happen in normal operation, since sc should only be null
+ // if the user app did not create a SparkContext.
+ if (!finished) {
+ throw new IllegalStateException("SparkContext is null but app is still running!")
+ }
+ }
userClassThread.join()
+ } catch {
+ case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
+ logError(
+ s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
+ "Please check earlier log output for errors. Failing the application.")
+ finish(FinalApplicationStatus.FAILED,
+ ApplicationMaster.EXIT_SC_NOT_INITED,
+ "Timed out waiting for SparkContext.")
}
}
@@ -409,7 +415,8 @@ private[spark] class ApplicationMaster(
clientMode = true)
val driverRef = waitForSparkDriver()
addAmIpFilter()
- registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
+ registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""),
+ securityMgr)
// In client mode the actor will stop the reporter thread.
reporterThread.join()
@@ -525,26 +532,6 @@ private[spark] class ApplicationMaster(
}
}
- private def waitForSparkContextInitialized(): SparkContext = {
- logInfo("Waiting for spark context initialization")
- sparkContextRef.synchronized {
- val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
- val deadline = System.currentTimeMillis() + totalWaitTime
-
- while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
- logInfo("Waiting for spark context initialization ... ")
- sparkContextRef.wait(10000L)
- }
-
- val sparkContext = sparkContextRef.get()
- if (sparkContext == null) {
- logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
- + " log output for errors. Failing the application.").format(totalWaitTime))
- }
- sparkContext
- }
- }
-
private def waitForSparkDriver(): RpcEndpointRef = {
logInfo("Waiting for Spark driver to be reachable.")
var driverUp = false
@@ -647,6 +634,13 @@ private[spark] class ApplicationMaster(
ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
"User class threw exception: " + cause)
}
+ sparkContextPromise.tryFailure(e.getCause())
+ } finally {
+ // Notify the thread waiting for the SparkContext, in case the application did not
+ // instantiate one. This will do nothing when the user code instantiates a SparkContext
+ // (with the correct master), or when the user code throws an exception (due to the
+ // tryFailure above).
+ sparkContextPromise.trySuccess(null)
}
}
}
@@ -759,10 +753,6 @@ object ApplicationMaster extends Logging {
master.sparkContextInitialized(sc)
}
- private[spark] def sparkContextStopped(sc: SparkContext): Boolean = {
- master.sparkContextStopped(sc)
- }
-
private[spark] def getAttemptId(): ApplicationAttemptId = {
master.getAttemptId
}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 72ec4d6b34..96c9151fc3 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -34,9 +34,4 @@ private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnSchedule
logInfo("YarnClusterScheduler.postStartHook done")
}
- override def stop() {
- super.stop()
- ApplicationMaster.sparkContextStopped(sc)
- }
-
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 8ab7b21c22..fb7926f6a1 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.launcher._
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
@@ -192,6 +193,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
}
}
+ test("timeout to get SparkContext in cluster mode triggers failure") {
+ val timeout = 2000
+ val finalState = runSpark(false, mainClassName(SparkContextTimeoutApp.getClass),
+ appArgs = Seq((timeout * 4).toString),
+ extraConf = Map(AM_MAX_WAIT_TIME.key -> timeout.toString))
+ finalState should be (SparkAppHandle.State.FAILED)
+ }
+
private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
@@ -469,3 +478,16 @@ private object YarnLauncherTestApp {
}
}
+
+/**
+ * Used to test code in the AM that detects the SparkContext instance. Expects a single argument
+ * with the duration to sleep for, in ms.
+ */
+private object SparkContextTimeoutApp {
+
+ def main(args: Array[String]): Unit = {
+ val Array(sleepTime) = args
+ Thread.sleep(java.lang.Long.parseLong(sleepTime))
+ }
+
+}