aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala245
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala50
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala11
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala9
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala8
-rw-r--r--network/yarn/pom.xml5
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala6
17 files changed, 266 insertions, 190 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 398e093690..23ae9360f6 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -252,7 +252,8 @@ object SparkEnv extends Logging {
// Create the ActorSystem for Akka and get the port it binds to.
val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
- val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
+ val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager,
+ clientMode = !isDriver)
val actorSystem: ActorSystem =
if (rpcEnv.isInstanceOf[AkkaRpcEnv]) {
rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
@@ -262,9 +263,11 @@ object SparkEnv extends Logging {
}
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
+ // In the non-driver case, the RPC env's address may be null since it may not be listening
+ // for incoming connections.
if (isDriver) {
conf.set("spark.driver.port", rpcEnv.address.port.toString)
- } else {
+ } else if (rpcEnv.address != null) {
conf.set("spark.executor.port", rpcEnv.address.port.toString)
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index a9c6a05ecd..c2ebf30596 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -45,8 +45,6 @@ private[spark] class CoarseGrainedExecutorBackend(
env: SparkEnv)
extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
- Utils.checkHostPort(hostPort, "Expected hostport")
-
var executor: Executor = null
@volatile var driver: Option[RpcEndpointRef] = None
@@ -80,9 +78,8 @@ private[spark] class CoarseGrainedExecutorBackend(
}
override def receive: PartialFunction[Any, Unit] = {
- case RegisteredExecutor =>
+ case RegisteredExecutor(hostname) =>
logInfo("Successfully registered with driver")
- val (hostname, _) = Utils.parseHostPort(hostPort)
executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
case RegisterExecutorFailed(message) =>
@@ -163,7 +160,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
hostname,
port,
executorConf,
- new SecurityManager(executorConf))
+ new SecurityManager(executorConf),
+ clientMode = true)
val driver = fetcher.setupEndpointRefByURI(driverUrl)
val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
Seq[(String, String)](("spark.app.id", appId))
@@ -188,12 +186,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val env = SparkEnv.createExecutorEnv(
driverConf, executorId, hostname, port, cores, isLocal = false)
- // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore.
- val boundPort = env.conf.getInt("spark.executor.port", 0)
- assert(boundPort != 0)
-
- // Start the CoarseGrainedExecutorBackend endpoint.
- val sparkHostPort = hostname + ":" + boundPort
+ // SparkEnv will set spark.executor.port if the rpc env is listening for incoming
+ // connections (e.g., if it's using akka). Otherwise, the executor is running in
+ // client mode only, and does not accept incoming connections.
+ val sparkHostPort = env.conf.getOption("spark.executor.port").map { port =>
+ hostname + ":" + port
+ }.orNull
env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
workerUrl.foreach { url =>
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 2c4a8b9a0a..a560fd10cd 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -43,9 +43,10 @@ private[spark] object RpcEnv {
host: String,
port: Int,
conf: SparkConf,
- securityManager: SecurityManager): RpcEnv = {
+ securityManager: SecurityManager,
+ clientMode: Boolean = false): RpcEnv = {
// Using Reflection to create the RpcEnv to avoid to depend on Akka directly
- val config = RpcEnvConfig(conf, name, host, port, securityManager)
+ val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode)
getRpcEnvFactory(conf).create(config)
}
}
@@ -139,4 +140,5 @@ private[spark] case class RpcEnvConfig(
name: String,
host: String,
port: Int,
- securityManager: SecurityManager)
+ securityManager: SecurityManager,
+ clientMode: Boolean)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index 7bf44a6565..eb25d6c7b7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -55,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private var stopped = false
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
- val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name)
+ val addr = RpcEndpointAddress(nettyEnv.address, name)
val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
synchronized {
if (stopped) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 284284eb80..09093819bb 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -17,10 +17,12 @@
package org.apache.spark.rpc.netty
import java.io._
+import java.lang.{Boolean => JBoolean}
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
+import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
@@ -29,6 +31,7 @@ import scala.reflect.ClassTag
import scala.util.{DynamicVariable, Failure, Success}
import scala.util.control.NonFatal
+import com.google.common.base.Preconditions
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.network.TransportContext
import org.apache.spark.network.client._
@@ -45,15 +48,14 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
- // Override numConnectionsPerPeer to 1 for RPC.
private val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
- private val transportContext =
- new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
+ private val transportContext = new TransportContext(transportConf,
+ new NettyRpcHandler(dispatcher, this))
private val clientFactory = {
val bootstraps: java.util.List[TransportClientBootstrap] =
@@ -95,7 +97,7 @@ private[netty] class NettyRpcEnv(
}
}
- def start(port: Int): Unit = {
+ def startServer(port: Int): Unit = {
val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
@@ -107,9 +109,9 @@ private[netty] class NettyRpcEnv(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
+ @Nullable
override lazy val address: RpcAddress = {
- require(server != null, "NettyRpcEnv has not yet started")
- RpcAddress(host, server.getPort)
+ if (server != null) RpcAddress(host, server.getPort()) else null
}
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
@@ -120,7 +122,7 @@ private[netty] class NettyRpcEnv(
val addr = RpcEndpointAddress(uri)
val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
val verifier = new NettyRpcEndpointRef(
- conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this)
+ conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this)
verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
if (find) {
Future.successful(endpointRef)
@@ -135,28 +137,34 @@ private[netty] class NettyRpcEnv(
dispatcher.stop(endpointRef)
}
- private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = {
- val targetOutbox = {
- val outbox = outboxes.get(address)
- if (outbox == null) {
- val newOutbox = new Outbox(this, address)
- val oldOutbox = outboxes.putIfAbsent(address, newOutbox)
- if (oldOutbox == null) {
- newOutbox
+ private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
+ if (receiver.client != null) {
+ receiver.client.sendRpc(message.content, message.createCallback(receiver.client));
+ } else {
+ require(receiver.address != null,
+ "Cannot send message to client endpoint with no listen address.")
+ val targetOutbox = {
+ val outbox = outboxes.get(receiver.address)
+ if (outbox == null) {
+ val newOutbox = new Outbox(this, receiver.address)
+ val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)
+ if (oldOutbox == null) {
+ newOutbox
+ } else {
+ oldOutbox
+ }
} else {
- oldOutbox
+ outbox
}
+ }
+ if (stopped.get) {
+ // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
+ outboxes.remove(receiver.address)
+ targetOutbox.stop()
} else {
- outbox
+ targetOutbox.send(message)
}
}
- if (stopped.get) {
- // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
- outboxes.remove(address)
- targetOutbox.stop()
- } else {
- targetOutbox.send(message)
- }
}
private[netty] def send(message: RequestMessage): Unit = {
@@ -174,17 +182,14 @@ private[netty] class NettyRpcEnv(
}(ThreadUtils.sameThread)
} else {
// Message to a remote RPC endpoint.
- postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
-
- override def onFailure(e: Throwable): Unit = {
+ postToOutbox(message.receiver, OutboxMessage(serialize(message),
+ (e) => {
logWarning(s"Exception when sending $message", e)
- }
-
- override def onSuccess(response: Array[Byte]): Unit = {
- val ack = deserialize[Ack](response)
+ },
+ (client, response) => {
+ val ack = deserialize[Ack](client, response)
logDebug(s"Receive ack from ${ack.sender}")
- }
- }))
+ }))
}
}
@@ -214,16 +219,14 @@ private[netty] class NettyRpcEnv(
}
}(ThreadUtils.sameThread)
} else {
- postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
-
- override def onFailure(e: Throwable): Unit = {
+ postToOutbox(message.receiver, OutboxMessage(serialize(message),
+ (e) => {
if (!promise.tryFailure(e)) {
logWarning("Ignore Exception", e)
}
- }
-
- override def onSuccess(response: Array[Byte]): Unit = {
- val reply = deserialize[AskResponse](response)
+ },
+ (client, response) => {
+ val reply = deserialize[AskResponse](client, response)
if (reply.reply.isInstanceOf[RpcFailure]) {
if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
logWarning(s"Ignore failure: ${reply.reply}")
@@ -231,8 +234,7 @@ private[netty] class NettyRpcEnv(
} else if (!promise.trySuccess(reply.reply)) {
logWarning(s"Ignore message: ${reply}")
}
- }
- }))
+ }))
}
promise.future
}
@@ -243,9 +245,11 @@ private[netty] class NettyRpcEnv(
buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
}
- private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
- deserialize { () =>
- javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+ private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = {
+ NettyRpcEnv.currentClient.withValue(client) {
+ deserialize { () =>
+ javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+ }
}
}
@@ -254,7 +258,7 @@ private[netty] class NettyRpcEnv(
}
override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
- new RpcEndpointAddress(address.host, address.port, endpointName).toString
+ new RpcEndpointAddress(address, endpointName).toString
override def shutdown(): Unit = {
cleanup()
@@ -297,6 +301,7 @@ private[netty] class NettyRpcEnv(
deserializationAction()
}
}
+
}
private[netty] object NettyRpcEnv extends Logging {
@@ -312,6 +317,13 @@ private[netty] object NettyRpcEnv extends Logging {
* }}}
*/
private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)
+
+ /**
+ * Similar to `currentEnv`, this variable references the client instance associated with an
+ * RPC, in case it's needed to find out the remote address during deserialization.
+ */
+ private[netty] val currentClient = new DynamicVariable[TransportClient](null)
+
}
private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
@@ -324,47 +336,68 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
val nettyEnv =
new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
- val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
- nettyEnv.start(actualPort)
- (nettyEnv, actualPort)
- }
- try {
- Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
- } catch {
- case NonFatal(e) =>
- nettyEnv.shutdown()
- throw e
+ if (!config.clientMode) {
+ val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
+ nettyEnv.startServer(actualPort)
+ (nettyEnv, actualPort)
+ }
+ try {
+ Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
+ } catch {
+ case NonFatal(e) =>
+ nettyEnv.shutdown()
+ throw e
+ }
}
+ nettyEnv
}
}
-private[netty] class NettyRpcEndpointRef(@transient private val conf: SparkConf)
+/**
+ * The NettyRpcEnv version of RpcEndpointRef.
+ *
+ * This class behaves differently depending on where it's created. On the node that "owns" the
+ * RpcEndpoint, it's a simple wrapper around the RpcEndpointAddress instance.
+ *
+ * On other machines that receive a serialized version of the reference, the behavior changes. The
+ * instance will keep track of the TransportClient that sent the reference, so that messages
+ * to the endpoint are sent over the client connection, instead of needing a new connection to
+ * be opened.
+ *
+ * The RpcAddress of this ref can be null; what that means is that the ref can only be used through
+ * a client connection, since the process hosting the endpoint is not listening for incoming
+ * connections. These refs should not be shared with 3rd parties, since they will not be able to
+ * send messages to the endpoint.
+ *
+ * @param conf Spark configuration.
+ * @param endpointAddress The address where the endpoint is listening.
+ * @param nettyEnv The RpcEnv associated with this ref.
+ * @param local Whether the referenced endpoint lives in the same process.
+ */
+private[netty] class NettyRpcEndpointRef(
+ @transient private val conf: SparkConf,
+ endpointAddress: RpcEndpointAddress,
+ @transient @volatile private var nettyEnv: NettyRpcEnv)
extends RpcEndpointRef(conf) with Serializable with Logging {
- @transient @volatile private var nettyEnv: NettyRpcEnv = _
+ @transient @volatile var client: TransportClient = _
- @transient @volatile private var _address: RpcEndpointAddress = _
+ private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
+ private val _name = endpointAddress.name
- def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) {
- this(conf)
- this._address = _address
- this.nettyEnv = nettyEnv
- }
-
- override def address: RpcAddress = _address.toRpcAddress
+ override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
- _address = in.readObject().asInstanceOf[RpcEndpointAddress]
nettyEnv = NettyRpcEnv.currentEnv.value
+ client = NettyRpcEnv.currentClient.value
}
private def writeObject(out: ObjectOutputStream): Unit = {
out.defaultWriteObject()
- out.writeObject(_address)
}
- override def name: String = _address.name
+ override def name: String = _name
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
val promise = Promise[Any]()
@@ -429,41 +462,43 @@ private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessa
private[netty] case class RpcFailure(e: Throwable)
/**
- * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast
- * network events and forward messages to [[Dispatcher]].
+ * Dispatches incoming RPCs to registered endpoints.
+ *
+ * The handler keeps track of all client instances that communicate with it, so that the RpcEnv
+ * knows which `TransportClient` instance to use when sending RPCs to a client endpoint (i.e.,
+ * one that is not listening for incoming connections, but rather needs to be contacted via the
+ * client socket).
+ *
+ * Events are sent on a per-connection basis, so if a client opens multiple connections to the
+ * RpcEnv, multiple connection / disconnection events will be created for that client (albeit
+ * with different `RpcAddress` information).
*/
private[netty] class NettyRpcHandler(
dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
- private type ClientAddress = RpcAddress
- private type RemoteEnvAddress = RpcAddress
-
- // Store all client addresses and their NettyRpcEnv addresses.
- // TODO: Is this even necessary?
- @GuardedBy("this")
- private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
+ // TODO: Can we add connection callback (channel registered) to the underlying framework?
+ // A variable to track whether we should dispatch the RemoteProcessConnected message.
+ private val clients = new ConcurrentHashMap[TransportClient, JBoolean]()
override def receive(
- client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
- val requestMessage = nettyEnv.deserialize[RequestMessage](message)
- val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
+ client: TransportClient,
+ message: Array[Byte],
+ callback: RpcResponseCallback): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
- val remoteEnvAddress = requestMessage.senderAddress
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-
- // TODO: Can we add connection callback (channel registered) to the underlying framework?
- // A variable to track whether we should dispatch the RemoteProcessConnected message.
- var dispatchRemoteProcessConnected = false
- synchronized {
- if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
- // clientAddr connects at the first time, fire "RemoteProcessConnected"
- dispatchRemoteProcessConnected = true
- }
+ if (clients.putIfAbsent(client, JBoolean.TRUE) == null) {
+ dispatcher.postToAll(RemoteProcessConnected(clientAddr))
}
- if (dispatchRemoteProcessConnected) {
- dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
- }
- dispatcher.postRemoteMessage(requestMessage, callback)
+ val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
+ val messageToDispatch = if (requestMessage.senderAddress == null) {
+ // Create a new message with the socket address of the client as the sender.
+ RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content,
+ requestMessage.needReply)
+ } else {
+ requestMessage
+ }
+ dispatcher.postRemoteMessage(messageToDispatch, callback)
}
override def getStreamManager: StreamManager = new OneForOneStreamManager
@@ -472,15 +507,7 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- val broadcastMessage =
- synchronized {
- remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _))
- }
- if (broadcastMessage.isEmpty) {
- logError(cause.getMessage, cause)
- } else {
- dispatcher.postToAll(broadcastMessage.get)
- }
+ dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr))
} else {
// If the channel is closed before connecting, its remoteAddress will be null.
// See java.net.Socket.getRemoteSocketAddress
@@ -493,15 +520,9 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ clients.remove(client)
nettyEnv.removeOutbox(clientAddr)
- val messageOpt: Option[RemoteProcessDisconnected] =
- synchronized {
- remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
- remoteAddresses -= clientAddr
- Some(RemoteProcessDisconnected(remoteEnvAddress))
- }
- }
- messageOpt.foreach(dispatcher.postToAll)
+ dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
} else {
// If the channel is closed before connecting, its remoteAddress will be null. In this case,
// we can ignore it since we don't fire "Associated".
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
index 7d9d593b36..2f6817f2eb 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -26,7 +26,21 @@ import org.apache.spark.SparkException
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.rpc.RpcAddress
-private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback)
+private[netty] case class OutboxMessage(content: Array[Byte],
+ _onFailure: (Throwable) => Unit,
+ _onSuccess: (TransportClient, Array[Byte]) => Unit) {
+
+ def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() {
+ override def onFailure(e: Throwable): Unit = {
+ _onFailure(e)
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ _onSuccess(client, response)
+ }
+ }
+
+}
private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
@@ -68,7 +82,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
}
}
if (dropped) {
- message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
} else {
drainOutbox()
}
@@ -108,7 +122,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
try {
val _client = synchronized { client }
if (_client != null) {
- _client.sendRpc(message.content, message.callback)
+ _client.sendRpc(message.content, message.createCallback(_client))
} else {
assert(stopped == true)
}
@@ -181,7 +195,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message.callback.onFailure(e)
+ message._onFailure(e)
message = messages.poll()
}
assert(messages.isEmpty)
@@ -215,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
message = messages.poll()
}
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
index 87b6236936..d2e94f943a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
@@ -23,15 +23,25 @@ import org.apache.spark.rpc.RpcAddress
/**
* An address identifier for an RPC endpoint.
*
- * @param host host name of the remote process.
- * @param port the port the remote RPC environment binds to.
- * @param name name of the remote endpoint.
+ * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
+ * connection and can only be reached via the client that sent the endpoint reference.
+ *
+ * @param rpcAddress The socket address of the endpint.
+ * @param name Name of the endpoint.
*/
-private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) {
+private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
+
+ require(name != null, "RpcEndpoint name must be provided.")
- def toRpcAddress: RpcAddress = RpcAddress(host, port)
+ def this(host: String, port: Int, name: String) = {
+ this(RpcAddress(host, port), name)
+ }
- override val toString = s"spark://$name@$host:$port"
+ override val toString = if (rpcAddress != null) {
+ s"spark://$name@${rpcAddress.host}:${rpcAddress.port}"
+ } else {
+ s"spark-client://$name"
+ }
}
private[netty] object RpcEndpointAddress {
@@ -51,7 +61,7 @@ private[netty] object RpcEndpointAddress {
uri.getQuery != null) {
throw new SparkException("Invalid Spark URL: " + sparkUrl)
}
- RpcEndpointAddress(host, port, name)
+ new RpcEndpointAddress(host, port, name)
} catch {
case e: java.net.URISyntaxException =>
throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 8103efa730..f3d0d85476 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -38,7 +38,7 @@ private[spark] object CoarseGrainedClusterMessages {
sealed trait RegisterExecutorResponse
- case object RegisteredExecutor extends CoarseGrainedClusterMessage
+ case class RegisteredExecutor(hostname: String) extends CoarseGrainedClusterMessage
with RegisterExecutorResponse
case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
@@ -51,9 +51,7 @@ private[spark] object CoarseGrainedClusterMessages {
hostPort: String,
cores: Int,
logUrls: Map[String, String])
- extends CoarseGrainedClusterMessage {
- Utils.checkHostPort(hostPort, "Expected host port")
- }
+ extends CoarseGrainedClusterMessage
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
data: SerializableBuffer) extends CoarseGrainedClusterMessage
@@ -107,8 +105,4 @@ private[spark] object CoarseGrainedClusterMessages {
// Used internally by executors to shut themselves down.
case object Shutdown extends CoarseGrainedClusterMessage
- // SPARK-10987: workaround for netty RPC issue; forces a connection from the driver back
- // to the AM.
- case object DriverHello extends CoarseGrainedClusterMessage
-
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 55a564b5c8..439a119270 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -131,16 +131,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) =>
- Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
if (executorDataMap.contains(executorId)) {
context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
} else {
- logInfo("Registered executor: " + executorRef + " with ID " + executorId)
- addressToExecutorId(executorRef.address) = executorId
+ // If the executor's rpc env is not listening for incoming connections, `hostPort`
+ // will be null, and the client connection should be used to contact the executor.
+ val executorAddress = if (executorRef.address != null) {
+ executorRef.address
+ } else {
+ context.senderAddress
+ }
+ logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId")
+ addressToExecutorId(executorAddress) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
- val (host, _) = Utils.parseHostPort(hostPort)
- val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls)
+ val data = new ExecutorData(executorRef, executorRef.address, executorAddress.host,
+ cores, cores, logUrls)
// This must be synchronized because variables mutated
// in this block are read when requesting executors
CoarseGrainedSchedulerBackend.this.synchronized {
@@ -151,7 +157,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
}
// Note: some tests expect the reply to come after we put the executor in the map
- context.reply(RegisteredExecutor)
+ context.reply(RegisteredExecutor(executorAddress.host))
listenerBus.post(
SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
makeOffers()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index e483688ede..cb24072d7d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -170,8 +170,6 @@ private[spark] abstract class YarnSchedulerBackend(
case RegisterClusterManager(am) =>
logInfo(s"ApplicationMaster registered as $am")
amEndpoint = Option(am)
- // See SPARK-10987.
- am.send(DriverHello)
case AddWebUIFilter(filterName, filterParams, proxyBase) =>
addWebUIFilter(filterName, filterParams, proxyBase)
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 3bead6395d..834e4743df 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -48,7 +48,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
- def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv
+ def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv
test("send a message locally") {
@volatile var message: String = null
@@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely")
try {
@@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely")
try {
@@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
val shortProp = "spark.rpc.short.timeout"
conf.set("spark.rpc.retry.wait", "0")
conf.set("spark.rpc.numRetries", "1")
- val anotherEnv = createRpcEnv(conf, "remote", 13345)
+ val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout")
try {
@@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely")
try {
@@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "sendWithReply-remotely-error")
@@ -497,26 +497,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "network-events")
val remoteAddress = anotherEnv.address
rpcEndpointRef.send("hello")
eventually(timeout(5 seconds), interval(5 millis)) {
- assert(events === List(("onConnected", remoteAddress)))
+ // anotherEnv is connected in client mode, so the remote address may be unknown depending on
+ // the implementation. Account for that when doing checks.
+ if (remoteAddress != null) {
+ assert(events === List(("onConnected", remoteAddress)))
+ } else {
+ assert(events.size === 1)
+ assert(events(0)._1 === "onConnected")
+ }
}
anotherEnv.shutdown()
anotherEnv.awaitTermination()
eventually(timeout(5 seconds), interval(5 millis)) {
- assert(events === List(
- ("onConnected", remoteAddress),
- ("onNetworkError", remoteAddress),
- ("onDisconnected", remoteAddress)) ||
- events === List(
- ("onConnected", remoteAddress),
- ("onDisconnected", remoteAddress)))
+ // Account for anotherEnv not having an address due to running in client mode.
+ if (remoteAddress != null) {
+ assert(events === List(
+ ("onConnected", remoteAddress),
+ ("onNetworkError", remoteAddress),
+ ("onDisconnected", remoteAddress)) ||
+ events === List(
+ ("onConnected", remoteAddress),
+ ("onDisconnected", remoteAddress)))
+ } else {
+ val eventNames = events.map(_._1)
+ assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") ||
+ eventNames === List("onConnected", "onDisconnected"))
+ }
}
}
@@ -529,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "sendWithReply-unserializable-error")
@@ -558,7 +572,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
conf.set("spark.authenticate.secret", "good")
val localEnv = createRpcEnv(conf, "authentication-local", 13345)
- val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true)
try {
@volatile var message: String = null
@@ -589,7 +603,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
conf.set("spark.authenticate.secret", "good")
val localEnv = createRpcEnv(conf, "authentication-local", 13345)
- val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true)
try {
localEnv.setupEndpoint("ask-authentication", new RpcEndpoint {
diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
index 4aa75c9230..6478ab51c4 100644
--- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
@@ -22,9 +22,12 @@ import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf}
class AkkaRpcEnvSuite extends RpcEnvSuite {
- override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
+ override def createRpcEnv(conf: SparkConf,
+ name: String,
+ port: Int,
+ clientMode: Boolean = false): RpcEnv = {
new AkkaRpcEnvFactory().create(
- RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf)))
+ RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode))
}
test("setupEndpointRef: systemName, address, endpointName") {
@@ -37,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
})
val conf = new SparkConf()
val newRpcEnv = new AkkaRpcEnvFactory().create(
- RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf)))
+ RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false))
try {
val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint")
assert(s"akka.tcp://local@${env.address}/user/test_endpoint" ===
@@ -56,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
val conf = SSLSampleConfigs.sparkSSLConfig()
val securityManager = new SecurityManager(conf)
val rpcEnv = new AkkaRpcEnvFactory().create(
- RpcEnvConfig(conf, "test", "localhost", 12346, securityManager))
+ RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false))
try {
val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint")
assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri)
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
index 973a07a0bd..56743ba650 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
@@ -22,8 +22,13 @@ import org.apache.spark.SparkFunSuite
class NettyRpcAddressSuite extends SparkFunSuite {
test("toString") {
- val addr = RpcEndpointAddress("localhost", 12345, "test")
+ val addr = new RpcEndpointAddress("localhost", 12345, "test")
assert(addr.toString === "spark://test@localhost:12345")
}
+ test("toString for client mode") {
+ val addr = RpcEndpointAddress(null, "test")
+ assert(addr.toString === "spark-client://test")
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
index be19668e17..ce83087ec0 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -22,8 +22,13 @@ import org.apache.spark.rpc._
class NettyRpcEnvSuite extends RpcEnvSuite {
- override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
- val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf))
+ override def createRpcEnv(
+ conf: SparkConf,
+ name: String,
+ port: Int,
+ clientMode: Boolean = false): RpcEnv = {
+ val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf),
+ clientMode)
new NettyRpcEnvFactory().create(config)
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index 5430e4c0c4..f9d8e80c98 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.rpc._
class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
- when(env.deserialize(any(classOf[Array[Byte]]))(any())).
+ when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
test("receive") {
@@ -42,7 +42,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
nettyRpcHandler.receive(client, null, null)
- verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
+ verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
}
test("connectionTerminated") {
@@ -57,9 +57,9 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
nettyRpcHandler.connectionTerminated(client)
- verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
+ verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
verify(dispatcher, times(1)).postToAll(
- RemoteProcessDisconnected(RpcAddress("localhost", 12345)))
+ RemoteProcessDisconnected(RpcAddress("localhost", 40000)))
}
}
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index 541ed9a8d0..e2360eff5c 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -54,6 +54,11 @@
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index c6a6d7ac56..12ae350e4c 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -321,7 +321,8 @@ private[spark] class ApplicationMaster(
private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
val port = sparkConf.getInt("spark.yarn.am.port", 0)
- rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr)
+ rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr,
+ clientMode = true)
val driverRef = waitForSparkDriver()
addAmIpFilter()
registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
@@ -574,9 +575,6 @@ private[spark] class ApplicationMaster(
case x: AddWebUIFilter =>
logInfo(s"Add WebUI Filter. $x")
driver.send(x)
-
- case DriverHello =>
- // SPARK-10987: no action needed for this message.
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {