Diffstat (limited to 'core/src/main/scala/org/apache')
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala             2
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala                      17
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala                             112
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala                     15
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala                   3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala   2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala              14
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AkkaUtils.scala                          19
-rw-r--r--  core/src/main/scala/org/apache/spark/util/RpcUtils.scala                           20
9 files changed, 162 insertions, 42 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index b3bb5f911d..334a5b1014 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -38,7 +38,7 @@ class WorkerWebUI(
extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI")
with Logging {
- private[ui] val timeout = RpcUtils.askTimeout(worker.conf)
+ private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf)
initialize()
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 69181edb9a..6ae4789459 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -17,8 +17,7 @@
package org.apache.spark.rpc
-import scala.concurrent.{Await, Future}
-import scala.concurrent.duration.FiniteDuration
+import scala.concurrent.Future
import scala.reflect.ClassTag
import org.apache.spark.util.RpcUtils
@@ -32,7 +31,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
private[this] val maxRetries = RpcUtils.numRetries(conf)
private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
- private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf)
+ private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
/**
* return the address for the [[RpcEndpointRef]]
@@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
*
* This method only sends the message once and never retries.
*/
- def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T]
+ def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and return a [[Future]] to
@@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
- def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = {
+ def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
// TODO: Consider removing multiple attempts
var attempts = 0
var lastException: Exception = null
@@ -99,7 +98,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
attempts += 1
try {
val future = ask[T](message, timeout)
- val result = Await.result(future, timeout)
+ val result = timeout.awaitResult(future)
if (result == null) {
throw new SparkException("Actor returned null")
}
@@ -110,10 +109,14 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
lastException = e
logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
}
- Thread.sleep(retryWaitMs)
+
+ if (attempts < maxRetries) {
+ Thread.sleep(retryWaitMs)
+ }
}
throw new SparkException(
s"Error sending message [message = $message]", lastException)
}
+
}
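
With this change callers pass an RpcTimeout rather than a bare FiniteDuration, so a timed-out ask names the configuration key that governs it. A minimal usage sketch, not part of the patch: endpointRef, the GetStatus message, and the property value below are hypothetical placeholders.

    import scala.concurrent.duration._
    import org.apache.spark.rpc.RpcTimeout

    // Tie the timeout to a configuration key so a failure reports it.
    val askTimeout = new RpcTimeout(30.seconds, "spark.rpc.askTimeout")
    // endpointRef and GetStatus are hypothetical; askWithRetry blocks for the reply.
    val ready: Boolean = endpointRef.askWithRetry[Boolean](GetStatus, askTimeout)
    // On timeout, the thrown RpcTimeoutException message ends with
    // "This timeout is controlled by spark.rpc.askTimeout".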
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 3b6938ec63..1709bdf560 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -18,8 +18,10 @@
package org.apache.spark.rpc
import java.net.URI
+import java.util.concurrent.TimeoutException
-import scala.concurrent.{Await, Future}
+import scala.concurrent.{Awaitable, Await, Future}
+import scala.concurrent.duration._
import scala.language.postfixOps
import org.apache.spark.{SecurityManager, SparkConf}
@@ -66,7 +68,7 @@ private[spark] object RpcEnv {
*/
private[spark] abstract class RpcEnv(conf: SparkConf) {
- private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf)
+ private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf)
/**
* Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement
@@ -94,7 +96,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action.
*/
def setupEndpointRefByURI(uri: String): RpcEndpointRef = {
- Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout)
+ defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri))
}
/**
@@ -184,3 +186,107 @@ private[spark] object RpcAddress {
RpcAddress(host, port)
}
}
+
+
+/**
+ * An exception thrown when an [[RpcTimeout]] adds its description to a [[TimeoutException]].
+ */
+private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
+ extends TimeoutException(message) { initCause(cause) }
+
+
+/**
+ * Associates a timeout with a description so that when a TimeoutException occurs, additional
+ * context about the timeout can be added to the exception message.
+ * @param duration timeout duration in seconds
+ * @param timeoutProp the configuration property that controls this timeout
+ */
+private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String)
+ extends Serializable {
+
+ /** Amends the standard message of TimeoutException to include the description */
+ private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
+ new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te)
+ }
+
+ /**
+ * PartialFunction to match a TimeoutException and add the timeout description to the message
+ *
+   * @note This can be used in the recover callback of a Future to add context to a TimeoutException
+ * Example:
+ * val timeout = new RpcTimeout(5 millis, "short timeout")
+ * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
+ */
+ def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
+    // The exception has already been converted to an RpcTimeoutException, so just rethrow it
+    case rte: RpcTimeoutException => throw rte
+    // Any other TimeoutException gets converted to an RpcTimeoutException with a modified message
+ case te: TimeoutException => throw createRpcTimeoutException(te)
+ }
+
+ /**
+ * Wait for the completed result and return it. If the result is not available within this
+ * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout.
+ * @param awaitable the `Awaitable` to be awaited
+ * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
+ * is still not ready
+ */
+ def awaitResult[T](awaitable: Awaitable[T]): T = {
+ try {
+ Await.result(awaitable, duration)
+ } catch addMessageIfTimeout
+ }
+}
+
+
+private[spark] object RpcTimeout {
+
+ /**
+ * Lookup the timeout property in the configuration and create
+ * a RpcTimeout with the property key in the description.
+ * @param conf configuration properties containing the timeout
+ * @param timeoutProp property key for the timeout in seconds
+ * @throws NoSuchElementException if property is not set
+ */
+ def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
+ val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds }
+ new RpcTimeout(timeout, timeoutProp)
+ }
+
+ /**
+ * Lookup the timeout property in the configuration and create
+ * a RpcTimeout with the property key in the description.
+ * Uses the given default value if property is not set
+ * @param conf configuration properties containing the timeout
+ * @param timeoutProp property key for the timeout in seconds
+ * @param defaultValue default timeout value in seconds if property not found
+ */
+ def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
+ val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds }
+ new RpcTimeout(timeout, timeoutProp)
+ }
+
+ /**
+ * Lookup prioritized list of timeout properties in the configuration
+ * and create a RpcTimeout with the first set property key in the
+ * description.
+ * Uses the given default value if property is not set
+ * @param conf configuration properties containing the timeout
+ * @param timeoutPropList prioritized list of property keys for the timeout in seconds
+ * @param defaultValue default timeout value in seconds if no properties found
+ */
+ def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = {
+ require(timeoutPropList.nonEmpty)
+
+ // Find the first set property or use the default value with the first property
+ val itr = timeoutPropList.iterator
+ var foundProp: Option[(String, String)] = None
+    while (itr.hasNext && foundProp.isEmpty) {
+      val propKey = itr.next()
+      conf.getOption(propKey).foreach { prop => foundProp = Some((propKey, prop)) }
+    }
+    val finalProp = foundProp.getOrElse((timeoutPropList.head, defaultValue))
+ val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds }
+ new RpcTimeout(timeout, finalProp._1)
+ }
+}
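
The factory methods above pick the first property in the list that is actually set, falling back to the given default, and awaitResult rethrows a plain TimeoutException as an RpcTimeoutException naming that property. A small sketch of the pieces together, assuming code living in the org.apache.spark package (the class is private[spark]); the property values are invented for illustration:

    import scala.concurrent.Future
    import scala.concurrent.ExecutionContext.Implicits.global
    import org.apache.spark.SparkConf
    import org.apache.spark.rpc.RpcTimeout

    val conf = new SparkConf(false).set("spark.network.timeout", "100s")  // askTimeout left unset
    val t = RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s")
    // t.duration is 100 seconds and t.timeoutProp is "spark.network.timeout",
    // because the first property in the list that is actually set wins.
    val answer = t.awaitResult(Future { 21 * 2 })  // completes normally with 42
    // Had the future timed out, the RpcTimeoutException message would end with
    // "This timeout is controlled by spark.network.timeout".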
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 31ebe5ac5b..f2d87f6834 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka
import java.util.concurrent.ConcurrentHashMap
import scala.concurrent.Future
-import scala.concurrent.duration._
import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
@@ -214,8 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] (
override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
import actorSystem.dispatcher
- actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout).
- map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
+ actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
+ map(new AkkaRpcEndpointRef(defaultAddress, _, conf)).
+      // In case the future created by resolveOne times out, we want the exception
+      // to indicate which configuration property determines the timeout.
+ recover(defaultLookupTimeout.addMessageIfTimeout)
}
override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
@@ -295,8 +297,8 @@ private[akka] class AkkaRpcEndpointRef(
actorRef ! AkkaMessage(message, false)
}
- override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
- actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
+ override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
+ actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
// The function will run in the calling thread, so it should be short and never block.
case msg @ AkkaMessage(message, reply) =>
if (reply) {
@@ -307,7 +309,8 @@ private[akka] class AkkaRpcEndpointRef(
}
case AkkaFailure(e) =>
Future.failed(e)
- }(ThreadUtils.sameThread).mapTo[T]
+ }(ThreadUtils.sameThread).mapTo[T].
+ recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
}
override def toString: String = s"${getClass.getSimpleName}($actorRef)"
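
The recover-based decoration used in ask above works on any Future that may fail with a TimeoutException. A hedged sketch: someFuture is a placeholder, and ThreadUtils.sameThread is the same Spark executor already used in this file:

    import scala.concurrent.Future
    import scala.concurrent.duration._
    import org.apache.spark.rpc.RpcTimeout
    import org.apache.spark.util.ThreadUtils

    val timeout = new RpcTimeout(5.seconds, "spark.rpc.askTimeout")
    // someFuture: Future[Any] is assumed to exist; decorate it so a bare
    // TimeoutException resurfaces as an RpcTimeoutException naming the property.
    val decorated: Future[Any] =
      someFuture.recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)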
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a7cf0c23d9..6841fa8357 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.storage._
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
@@ -188,7 +189,7 @@ class DAGScheduler(
blockManagerId: BlockManagerId): Boolean = {
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
blockManagerMaster.driverEndpoint.askWithRetry[Boolean](
- BlockManagerHeartbeat(blockManagerId), 600 seconds)
+ BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
}
// Called by TaskScheduler when an executor fails.
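
Note that the RpcTimeout description is only echoed in the error message, so a plain label works where no single configuration key applies, as in the heartbeat call above. A minimal illustration mirroring that call:

    import scala.concurrent.duration._
    // The label is reported verbatim if the heartbeat ask times out.
    val heartbeatTimeout = new RpcTimeout(600.seconds, "BlockManagerHeartbeat")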
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 190ff61d68..bc67abb5df 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend(
private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint(
YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv))
- private implicit val askTimeout = RpcUtils.askTimeout(sc.conf)
+ private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf)
/**
* Request executors from the ApplicationMaster by specifying the total number desired.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 7cdae22b0e..f70f701494 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -33,7 +33,7 @@ class BlockManagerMaster(
isDriver: Boolean)
extends Logging {
- val timeout = RpcUtils.askTimeout(conf)
+ val timeout = RpcUtils.askRpcTimeout(conf)
/** Remove a dead executor from the driver endpoint. This is only called on the driver side. */
def removeExecutor(execId: String) {
@@ -106,7 +106,7 @@ class BlockManagerMaster(
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e)
}(ThreadUtils.sameThread)
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}
@@ -118,7 +118,7 @@ class BlockManagerMaster(
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e)
}(ThreadUtils.sameThread)
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}
@@ -132,7 +132,7 @@ class BlockManagerMaster(
s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e)
}(ThreadUtils.sameThread)
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}
@@ -176,8 +176,8 @@ class BlockManagerMaster(
CanBuildFrom[Iterable[Future[Option[BlockStatus]]],
Option[BlockStatus],
Iterable[Option[BlockStatus]]]]
- val blockStatus = Await.result(
- Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread), timeout)
+ val blockStatus = timeout.awaitResult(
+ Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread))
if (blockStatus == null) {
throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
}
@@ -199,7 +199,7 @@ class BlockManagerMaster(
askSlaves: Boolean): Seq[BlockId] = {
val msg = GetMatchingBlockIds(filter, askSlaves)
val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg)
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 96aa2fe164..c179833e5b 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -18,8 +18,6 @@
package org.apache.spark.util
import scala.collection.JavaConversions.mapAsJavaMap
-import scala.concurrent.Await
-import scala.concurrent.duration.FiniteDuration
import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem}
import akka.pattern.ask
@@ -28,6 +26,7 @@ import com.typesafe.config.ConfigFactory
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.rpc.RpcTimeout
/**
* Various utility classes for working with Akka.
@@ -147,7 +146,7 @@ private[spark] object AkkaUtils extends Logging {
def askWithReply[T](
message: Any,
actor: ActorRef,
- timeout: FiniteDuration): T = {
+ timeout: RpcTimeout): T = {
askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout)
}
@@ -160,7 +159,7 @@ private[spark] object AkkaUtils extends Logging {
actor: ActorRef,
maxAttempts: Int,
retryInterval: Long,
- timeout: FiniteDuration): T = {
+ timeout: RpcTimeout): T = {
// TODO: Consider removing multiple attempts
if (actor == null) {
throw new SparkException(s"Error sending message [message = $message]" +
@@ -171,8 +170,8 @@ private[spark] object AkkaUtils extends Logging {
while (attempts < maxAttempts) {
attempts += 1
try {
- val future = actor.ask(message)(timeout)
- val result = Await.result(future, timeout)
+ val future = actor.ask(message)(timeout.duration)
+ val result = timeout.awaitResult(future)
if (result == null) {
throw new SparkException("Actor returned null")
}
@@ -198,9 +197,9 @@ private[spark] object AkkaUtils extends Logging {
val driverPort: Int = conf.getInt("spark.driver.port", 7077)
Utils.checkHost(driverHost, "Expected hostname")
val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name)
- val timeout = RpcUtils.lookupTimeout(conf)
+ val timeout = RpcUtils.lookupRpcTimeout(conf)
logInfo(s"Connecting to $name: $url")
- Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
}
def makeExecutorRef(
@@ -212,9 +211,9 @@ private[spark] object AkkaUtils extends Logging {
val executorActorSystemName = SparkEnv.executorActorSystemName
Utils.checkHost(host, "Expected hostname")
val url = address(protocol(actorSystem), executorActorSystemName, host, port, name)
- val timeout = RpcUtils.lookupTimeout(conf)
+ val timeout = RpcUtils.lookupRpcTimeout(conf)
logInfo(s"Connecting to $name: $url")
- Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
}
def protocol(actorSystem: ActorSystem): String = {
diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
index f16cc8e7e4..7578a3b1d8 100644
--- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
@@ -17,11 +17,11 @@
package org.apache.spark.util
-import scala.concurrent.duration._
+import scala.concurrent.duration.FiniteDuration
import scala.language.postfixOps
import org.apache.spark.{SparkEnv, SparkConf}
-import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}
object RpcUtils {
@@ -47,14 +47,22 @@ object RpcUtils {
}
/** Returns the default Spark timeout to use for RPC ask operations. */
+ private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = {
+ RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s")
+ }
+
+ @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0")
def askTimeout(conf: SparkConf): FiniteDuration = {
- conf.getTimeAsSeconds("spark.rpc.askTimeout",
- conf.get("spark.network.timeout", "120s")) seconds
+ askRpcTimeout(conf).duration
}
/** Returns the default Spark timeout to use for RPC remote endpoint lookup. */
+ private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = {
+ RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s")
+ }
+
+ @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0")
def lookupTimeout(conf: SparkConf): FiniteDuration = {
- conf.getTimeAsSeconds("spark.rpc.lookupTimeout",
- conf.get("spark.network.timeout", "120s")) seconds
+ lookupRpcTimeout(conf).duration
}
}
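
The precedence encoded by askRpcTimeout and lookupRpcTimeout is: the RPC-specific key first, then spark.network.timeout, then the 120s default. A sketch of the resulting behaviour, assuming code inside the org.apache.spark package since both helpers are private[spark]; the values are invented for illustration:

    import scala.concurrent.duration._
    import org.apache.spark.SparkConf
    import org.apache.spark.util.RpcUtils

    val conf = new SparkConf(false)                                  // start with nothing set
    assert(RpcUtils.askRpcTimeout(conf).duration == 120.seconds)     // falls back to the default
    conf.set("spark.network.timeout", "90s")
    assert(RpcUtils.askRpcTimeout(conf).timeoutProp == "spark.network.timeout")
    conf.set("spark.rpc.askTimeout", "30s")
    assert(RpcUtils.askRpcTimeout(conf).duration == 30.seconds)      // the specific key wins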