author     Marcelo Vanzin <vanzin@cloudera.com>    2015-11-02 10:26:36 -0800
committer  Marcelo Vanzin <vanzin@cloudera.com>    2015-11-02 10:26:36 -0800
commit     71d1c907dec446db566b19f912159fd8f46deb7d (patch)
tree       8201803d422933421b5af731d61ef7dcd54ca6da
parent     a930e624eb9feb0f7d37d99dcb8178feb9c0f177 (diff)
[SPARK-10997][CORE] Add "client mode" to netty rpc env.
"Client mode" means the RPC env will not listen for incoming connections. This allows certain processes in the Spark stack (such as Executors or tha YARN client-mode AM) to act as pure clients when using the netty-based RPC backend, reducing the number of sockets needed by the app and also the number of open ports. Client connections are also preferred when endpoints that actually have a listening socket are involved; so, for example, if a Worker connects to a Master and the Master needs to send a message to a Worker endpoint, that client connection will be used, even though the Worker is also listening for incoming connections. With this change, the workaround for SPARK-10987 isn't necessary anymore, and is removed. The AM connects to the driver in "client mode", and that connection is used for all driver <-> AM communication, and so the AM is properly notified when the connection goes down. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #9210 from vanzin/SPARK-10997.
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 245
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala | 50
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala | 11
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala | 7
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala | 9
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala | 8
-rw-r--r--  network/yarn/pom.xml | 5
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 6
17 files changed, 266 insertions(+), 190 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 398e093690..23ae9360f6 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -252,7 +252,8 @@ object SparkEnv extends Logging {
// Create the ActorSystem for Akka and get the port it binds to.
val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
- val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
+ val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager,
+ clientMode = !isDriver)
val actorSystem: ActorSystem =
if (rpcEnv.isInstanceOf[AkkaRpcEnv]) {
rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
@@ -262,9 +263,11 @@ object SparkEnv extends Logging {
}
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
+ // In the non-driver case, the RPC env's address may be null since it may not be listening
+ // for incoming connections.
if (isDriver) {
conf.set("spark.driver.port", rpcEnv.address.port.toString)
- } else {
+ } else if (rpcEnv.address != null) {
conf.set("spark.executor.port", rpcEnv.address.port.toString)
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index a9c6a05ecd..c2ebf30596 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -45,8 +45,6 @@ private[spark] class CoarseGrainedExecutorBackend(
env: SparkEnv)
extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
- Utils.checkHostPort(hostPort, "Expected hostport")
-
var executor: Executor = null
@volatile var driver: Option[RpcEndpointRef] = None
@@ -80,9 +78,8 @@ private[spark] class CoarseGrainedExecutorBackend(
}
override def receive: PartialFunction[Any, Unit] = {
- case RegisteredExecutor =>
+ case RegisteredExecutor(hostname) =>
logInfo("Successfully registered with driver")
- val (hostname, _) = Utils.parseHostPort(hostPort)
executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
case RegisterExecutorFailed(message) =>
@@ -163,7 +160,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
hostname,
port,
executorConf,
- new SecurityManager(executorConf))
+ new SecurityManager(executorConf),
+ clientMode = true)
val driver = fetcher.setupEndpointRefByURI(driverUrl)
val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
Seq[(String, String)](("spark.app.id", appId))
@@ -188,12 +186,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val env = SparkEnv.createExecutorEnv(
driverConf, executorId, hostname, port, cores, isLocal = false)
- // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore.
- val boundPort = env.conf.getInt("spark.executor.port", 0)
- assert(boundPort != 0)
-
- // Start the CoarseGrainedExecutorBackend endpoint.
- val sparkHostPort = hostname + ":" + boundPort
+ // SparkEnv will set spark.executor.port if the rpc env is listening for incoming
+ // connections (e.g., if it's using akka). Otherwise, the executor is running in
+ // client mode only, and does not accept incoming connections.
+ val sparkHostPort = env.conf.getOption("spark.executor.port").map { port =>
+ hostname + ":" + port
+ }.orNull
env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
workerUrl.foreach { url =>
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 2c4a8b9a0a..a560fd10cd 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -43,9 +43,10 @@ private[spark] object RpcEnv {
host: String,
port: Int,
conf: SparkConf,
- securityManager: SecurityManager): RpcEnv = {
+ securityManager: SecurityManager,
+ clientMode: Boolean = false): RpcEnv = {
// Using Reflection to create the RpcEnv to avoid to depend on Akka directly
- val config = RpcEnvConfig(conf, name, host, port, securityManager)
+ val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode)
getRpcEnvFactory(conf).create(config)
}
}
@@ -139,4 +140,5 @@ private[spark] case class RpcEnvConfig(
name: String,
host: String,
port: Int,
- securityManager: SecurityManager)
+ securityManager: SecurityManager,
+ clientMode: Boolean)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index 7bf44a6565..eb25d6c7b7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -55,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private var stopped = false
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
- val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name)
+ val addr = RpcEndpointAddress(nettyEnv.address, name)
val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
synchronized {
if (stopped) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 284284eb80..09093819bb 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -17,10 +17,12 @@
package org.apache.spark.rpc.netty
import java.io._
+import java.lang.{Boolean => JBoolean}
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
+import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
@@ -29,6 +31,7 @@ import scala.reflect.ClassTag
import scala.util.{DynamicVariable, Failure, Success}
import scala.util.control.NonFatal
+import com.google.common.base.Preconditions
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.network.TransportContext
import org.apache.spark.network.client._
@@ -45,15 +48,14 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
- // Override numConnectionsPerPeer to 1 for RPC.
private val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
- private val transportContext =
- new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
+ private val transportContext = new TransportContext(transportConf,
+ new NettyRpcHandler(dispatcher, this))
private val clientFactory = {
val bootstraps: java.util.List[TransportClientBootstrap] =
@@ -95,7 +97,7 @@ private[netty] class NettyRpcEnv(
}
}
- def start(port: Int): Unit = {
+ def startServer(port: Int): Unit = {
val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
@@ -107,9 +109,9 @@ private[netty] class NettyRpcEnv(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
+ @Nullable
override lazy val address: RpcAddress = {
- require(server != null, "NettyRpcEnv has not yet started")
- RpcAddress(host, server.getPort)
+ if (server != null) RpcAddress(host, server.getPort()) else null
}
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
@@ -120,7 +122,7 @@ private[netty] class NettyRpcEnv(
val addr = RpcEndpointAddress(uri)
val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
val verifier = new NettyRpcEndpointRef(
- conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this)
+ conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this)
verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
if (find) {
Future.successful(endpointRef)
@@ -135,28 +137,34 @@ private[netty] class NettyRpcEnv(
dispatcher.stop(endpointRef)
}
- private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = {
- val targetOutbox = {
- val outbox = outboxes.get(address)
- if (outbox == null) {
- val newOutbox = new Outbox(this, address)
- val oldOutbox = outboxes.putIfAbsent(address, newOutbox)
- if (oldOutbox == null) {
- newOutbox
+ private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
+ if (receiver.client != null) {
+ receiver.client.sendRpc(message.content, message.createCallback(receiver.client));
+ } else {
+ require(receiver.address != null,
+ "Cannot send message to client endpoint with no listen address.")
+ val targetOutbox = {
+ val outbox = outboxes.get(receiver.address)
+ if (outbox == null) {
+ val newOutbox = new Outbox(this, receiver.address)
+ val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)
+ if (oldOutbox == null) {
+ newOutbox
+ } else {
+ oldOutbox
+ }
} else {
- oldOutbox
+ outbox
}
+ }
+ if (stopped.get) {
+ // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
+ outboxes.remove(receiver.address)
+ targetOutbox.stop()
} else {
- outbox
+ targetOutbox.send(message)
}
}
- if (stopped.get) {
- // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
- outboxes.remove(address)
- targetOutbox.stop()
- } else {
- targetOutbox.send(message)
- }
}
private[netty] def send(message: RequestMessage): Unit = {
@@ -174,17 +182,14 @@ private[netty] class NettyRpcEnv(
}(ThreadUtils.sameThread)
} else {
// Message to a remote RPC endpoint.
- postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
-
- override def onFailure(e: Throwable): Unit = {
+ postToOutbox(message.receiver, OutboxMessage(serialize(message),
+ (e) => {
logWarning(s"Exception when sending $message", e)
- }
-
- override def onSuccess(response: Array[Byte]): Unit = {
- val ack = deserialize[Ack](response)
+ },
+ (client, response) => {
+ val ack = deserialize[Ack](client, response)
logDebug(s"Receive ack from ${ack.sender}")
- }
- }))
+ }))
}
}
@@ -214,16 +219,14 @@ private[netty] class NettyRpcEnv(
}
}(ThreadUtils.sameThread)
} else {
- postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
-
- override def onFailure(e: Throwable): Unit = {
+ postToOutbox(message.receiver, OutboxMessage(serialize(message),
+ (e) => {
if (!promise.tryFailure(e)) {
logWarning("Ignore Exception", e)
}
- }
-
- override def onSuccess(response: Array[Byte]): Unit = {
- val reply = deserialize[AskResponse](response)
+ },
+ (client, response) => {
+ val reply = deserialize[AskResponse](client, response)
if (reply.reply.isInstanceOf[RpcFailure]) {
if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
logWarning(s"Ignore failure: ${reply.reply}")
@@ -231,8 +234,7 @@ private[netty] class NettyRpcEnv(
} else if (!promise.trySuccess(reply.reply)) {
logWarning(s"Ignore message: ${reply}")
}
- }
- }))
+ }))
}
promise.future
}
@@ -243,9 +245,11 @@ private[netty] class NettyRpcEnv(
buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
}
- private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
- deserialize { () =>
- javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+ private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = {
+ NettyRpcEnv.currentClient.withValue(client) {
+ deserialize { () =>
+ javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+ }
}
}
@@ -254,7 +258,7 @@ private[netty] class NettyRpcEnv(
}
override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
- new RpcEndpointAddress(address.host, address.port, endpointName).toString
+ new RpcEndpointAddress(address, endpointName).toString
override def shutdown(): Unit = {
cleanup()
@@ -297,6 +301,7 @@ private[netty] class NettyRpcEnv(
deserializationAction()
}
}
+
}
private[netty] object NettyRpcEnv extends Logging {
@@ -312,6 +317,13 @@ private[netty] object NettyRpcEnv extends Logging {
* }}}
*/
private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)
+
+ /**
+ * Similar to `currentEnv`, this variable references the client instance associated with an
+ * RPC, in case it's needed to find out the remote address during deserialization.
+ */
+ private[netty] val currentClient = new DynamicVariable[TransportClient](null)
+
}
private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
@@ -324,47 +336,68 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
val nettyEnv =
new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
- val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
- nettyEnv.start(actualPort)
- (nettyEnv, actualPort)
- }
- try {
- Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
- } catch {
- case NonFatal(e) =>
- nettyEnv.shutdown()
- throw e
+ if (!config.clientMode) {
+ val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
+ nettyEnv.startServer(actualPort)
+ (nettyEnv, actualPort)
+ }
+ try {
+ Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
+ } catch {
+ case NonFatal(e) =>
+ nettyEnv.shutdown()
+ throw e
+ }
}
+ nettyEnv
}
}
-private[netty] class NettyRpcEndpointRef(@transient private val conf: SparkConf)
+/**
+ * The NettyRpcEnv version of RpcEndpointRef.
+ *
+ * This class behaves differently depending on where it's created. On the node that "owns" the
+ * RpcEndpoint, it's a simple wrapper around the RpcEndpointAddress instance.
+ *
+ * On other machines that receive a serialized version of the reference, the behavior changes. The
+ * instance will keep track of the TransportClient that sent the reference, so that messages
+ * to the endpoint are sent over the client connection, instead of needing a new connection to
+ * be opened.
+ *
+ * The RpcAddress of this ref can be null; what that means is that the ref can only be used through
+ * a client connection, since the process hosting the endpoint is not listening for incoming
+ * connections. These refs should not be shared with 3rd parties, since they will not be able to
+ * send messages to the endpoint.
+ *
+ * @param conf Spark configuration.
+ * @param endpointAddress The address where the endpoint is listening.
+ * @param nettyEnv The RpcEnv associated with this ref.
+ * @param local Whether the referenced endpoint lives in the same process.
+ */
+private[netty] class NettyRpcEndpointRef(
+ @transient private val conf: SparkConf,
+ endpointAddress: RpcEndpointAddress,
+ @transient @volatile private var nettyEnv: NettyRpcEnv)
extends RpcEndpointRef(conf) with Serializable with Logging {
- @transient @volatile private var nettyEnv: NettyRpcEnv = _
+ @transient @volatile var client: TransportClient = _
- @transient @volatile private var _address: RpcEndpointAddress = _
+ private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
+ private val _name = endpointAddress.name
- def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) {
- this(conf)
- this._address = _address
- this.nettyEnv = nettyEnv
- }
-
- override def address: RpcAddress = _address.toRpcAddress
+ override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
- _address = in.readObject().asInstanceOf[RpcEndpointAddress]
nettyEnv = NettyRpcEnv.currentEnv.value
+ client = NettyRpcEnv.currentClient.value
}
private def writeObject(out: ObjectOutputStream): Unit = {
out.defaultWriteObject()
- out.writeObject(_address)
}
- override def name: String = _address.name
+ override def name: String = _name
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
val promise = Promise[Any]()
@@ -429,41 +462,43 @@ private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessa
private[netty] case class RpcFailure(e: Throwable)
/**
- * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast
- * network events and forward messages to [[Dispatcher]].
+ * Dispatches incoming RPCs to registered endpoints.
+ *
+ * The handler keeps track of all client instances that communicate with it, so that the RpcEnv
+ * knows which `TransportClient` instance to use when sending RPCs to a client endpoint (i.e.,
+ * one that is not listening for incoming connections, but rather needs to be contacted via the
+ * client socket).
+ *
+ * Events are sent on a per-connection basis, so if a client opens multiple connections to the
+ * RpcEnv, multiple connection / disconnection events will be created for that client (albeit
+ * with different `RpcAddress` information).
*/
private[netty] class NettyRpcHandler(
dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
- private type ClientAddress = RpcAddress
- private type RemoteEnvAddress = RpcAddress
-
- // Store all client addresses and their NettyRpcEnv addresses.
- // TODO: Is this even necessary?
- @GuardedBy("this")
- private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
+ // TODO: Can we add connection callback (channel registered) to the underlying framework?
+ // A variable to track whether we should dispatch the RemoteProcessConnected message.
+ private val clients = new ConcurrentHashMap[TransportClient, JBoolean]()
override def receive(
- client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
- val requestMessage = nettyEnv.deserialize[RequestMessage](message)
- val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
+ client: TransportClient,
+ message: Array[Byte],
+ callback: RpcResponseCallback): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
- val remoteEnvAddress = requestMessage.senderAddress
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-
- // TODO: Can we add connection callback (channel registered) to the underlying framework?
- // A variable to track whether we should dispatch the RemoteProcessConnected message.
- var dispatchRemoteProcessConnected = false
- synchronized {
- if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
- // clientAddr connects at the first time, fire "RemoteProcessConnected"
- dispatchRemoteProcessConnected = true
- }
+ if (clients.putIfAbsent(client, JBoolean.TRUE) == null) {
+ dispatcher.postToAll(RemoteProcessConnected(clientAddr))
}
- if (dispatchRemoteProcessConnected) {
- dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
- }
- dispatcher.postRemoteMessage(requestMessage, callback)
+ val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
+ val messageToDispatch = if (requestMessage.senderAddress == null) {
+ // Create a new message with the socket address of the client as the sender.
+ RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content,
+ requestMessage.needReply)
+ } else {
+ requestMessage
+ }
+ dispatcher.postRemoteMessage(messageToDispatch, callback)
}
override def getStreamManager: StreamManager = new OneForOneStreamManager
@@ -472,15 +507,7 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- val broadcastMessage =
- synchronized {
- remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _))
- }
- if (broadcastMessage.isEmpty) {
- logError(cause.getMessage, cause)
- } else {
- dispatcher.postToAll(broadcastMessage.get)
- }
+ dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr))
} else {
// If the channel is closed before connecting, its remoteAddress will be null.
// See java.net.Socket.getRemoteSocketAddress
@@ -493,15 +520,9 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ clients.remove(client)
nettyEnv.removeOutbox(clientAddr)
- val messageOpt: Option[RemoteProcessDisconnected] =
- synchronized {
- remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
- remoteAddresses -= clientAddr
- Some(RemoteProcessDisconnected(remoteEnvAddress))
- }
- }
- messageOpt.foreach(dispatcher.postToAll)
+ dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
} else {
// If the channel is closed before connecting, its remoteAddress will be null. In this case,
// we can ignore it since we don't fire "Associated".
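The `currentClient` mechanism above relies on Scala's `DynamicVariable`: `withValue` installs a value for the current thread while a block runs, which is how `NettyRpcEndpointRef.readObject` can see the `TransportClient` that delivered a serialized ref without any extra parameter. A self-contained sketch of the pattern, with illustrative names (not Spark's):

    import scala.util.DynamicVariable

    object Context {
      // Mirrors NettyRpcEnv.currentClient: null outside any withValue block.
      val current = new DynamicVariable[String](null)
    }

    // Stands in for readObject(): it takes no client parameter, yet it
    // observes the value installed by its caller.
    def deepInsideDeserialization(): String = Context.current.value

    val seen = Context.current.withValue("client-42") {
      deepInsideDeserialization()
    }
    assert(seen == "client-42")           // value visible inside the block
    assert(Context.current.value == null) // restored afterwards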
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
index 7d9d593b36..2f6817f2eb 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -26,7 +26,21 @@ import org.apache.spark.SparkException
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.rpc.RpcAddress
-private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback)
+private[netty] case class OutboxMessage(content: Array[Byte],
+ _onFailure: (Throwable) => Unit,
+ _onSuccess: (TransportClient, Array[Byte]) => Unit) {
+
+ def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() {
+ override def onFailure(e: Throwable): Unit = {
+ _onFailure(e)
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ _onSuccess(client, response)
+ }
+ }
+
+}
private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
@@ -68,7 +82,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
}
}
if (dropped) {
- message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
} else {
drainOutbox()
}
@@ -108,7 +122,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
try {
val _client = synchronized { client }
if (_client != null) {
- _client.sendRpc(message.content, message.callback)
+ _client.sendRpc(message.content, message.createCallback(_client))
} else {
assert(stopped == true)
}
@@ -181,7 +195,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message.callback.onFailure(e)
+ message._onFailure(e)
message = messages.poll()
}
assert(messages.isEmpty)
@@ -215,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
message = messages.poll()
}
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
index 87b6236936..d2e94f943a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
@@ -23,15 +23,25 @@ import org.apache.spark.rpc.RpcAddress
/**
* An address identifier for an RPC endpoint.
*
- * @param host host name of the remote process.
- * @param port the port the remote RPC environment binds to.
- * @param name name of the remote endpoint.
+ * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
+ * connection and can only be reached via the client that sent the endpoint reference.
+ *
+ * @param rpcAddress The socket address of the endpoint.
+ * @param name Name of the endpoint.
*/
-private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) {
+private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
+
+ require(name != null, "RpcEndpoint name must be provided.")
- def toRpcAddress: RpcAddress = RpcAddress(host, port)
+ def this(host: String, port: Int, name: String) = {
+ this(RpcAddress(host, port), name)
+ }
- override val toString = s"spark://$name@$host:$port"
+ override val toString = if (rpcAddress != null) {
+ s"spark://$name@${rpcAddress.host}:${rpcAddress.port}"
+ } else {
+ s"spark-client://$name"
+ }
}
private[netty] object RpcEndpointAddress {
@@ -51,7 +61,7 @@ private[netty] object RpcEndpointAddress {
uri.getQuery != null) {
throw new SparkException("Invalid Spark URL: " + sparkUrl)
}
- RpcEndpointAddress(host, port, name)
+ new RpcEndpointAddress(host, port, name)
} catch {
case e: java.net.URISyntaxException =>
throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 8103efa730..f3d0d85476 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -38,7 +38,7 @@ private[spark] object CoarseGrainedClusterMessages {
sealed trait RegisterExecutorResponse
- case object RegisteredExecutor extends CoarseGrainedClusterMessage
+ case class RegisteredExecutor(hostname: String) extends CoarseGrainedClusterMessage
with RegisterExecutorResponse
case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
@@ -51,9 +51,7 @@ private[spark] object CoarseGrainedClusterMessages {
hostPort: String,
cores: Int,
logUrls: Map[String, String])
- extends CoarseGrainedClusterMessage {
- Utils.checkHostPort(hostPort, "Expected host port")
- }
+ extends CoarseGrainedClusterMessage
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
data: SerializableBuffer) extends CoarseGrainedClusterMessage
@@ -107,8 +105,4 @@ private[spark] object CoarseGrainedClusterMessages {
// Used internally by executors to shut themselves down.
case object Shutdown extends CoarseGrainedClusterMessage
- // SPARK-10987: workaround for netty RPC issue; forces a connection from the driver back
- // to the AM.
- case object DriverHello extends CoarseGrainedClusterMessage
-
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 55a564b5c8..439a119270 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -131,16 +131,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) =>
- Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
if (executorDataMap.contains(executorId)) {
context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
} else {
- logInfo("Registered executor: " + executorRef + " with ID " + executorId)
- addressToExecutorId(executorRef.address) = executorId
+ // If the executor's rpc env is not listening for incoming connections, `hostPort`
+ // will be null, and the client connection should be used to contact the executor.
+ val executorAddress = if (executorRef.address != null) {
+ executorRef.address
+ } else {
+ context.senderAddress
+ }
+ logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId")
+ addressToExecutorId(executorAddress) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
- val (host, _) = Utils.parseHostPort(hostPort)
- val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls)
+ val data = new ExecutorData(executorRef, executorRef.address, executorAddress.host,
+ cores, cores, logUrls)
// This must be synchronized because variables mutated
// in this block are read when requesting executors
CoarseGrainedSchedulerBackend.this.synchronized {
@@ -151,7 +157,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
}
// Note: some tests expect the reply to come after we put the executor in the map
- context.reply(RegisteredExecutor)
+ context.reply(RegisteredExecutor(executorAddress.host))
listenerBus.post(
SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
makeOffers()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index e483688ede..cb24072d7d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -170,8 +170,6 @@ private[spark] abstract class YarnSchedulerBackend(
case RegisterClusterManager(am) =>
logInfo(s"ApplicationMaster registered as $am")
amEndpoint = Option(am)
- // See SPARK-10987.
- am.send(DriverHello)
case AddWebUIFilter(filterName, filterParams, proxyBase) =>
addWebUIFilter(filterName, filterParams, proxyBase)
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 3bead6395d..834e4743df 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -48,7 +48,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
- def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv
+ def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv
test("send a message locally") {
@volatile var message: String = null
@@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely")
try {
@@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely")
try {
@@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
val shortProp = "spark.rpc.short.timeout"
conf.set("spark.rpc.retry.wait", "0")
conf.set("spark.rpc.numRetries", "1")
- val anotherEnv = createRpcEnv(conf, "remote", 13345)
+ val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout")
try {
@@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely")
try {
@@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "sendWithReply-remotely-error")
@@ -497,26 +497,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "network-events")
val remoteAddress = anotherEnv.address
rpcEndpointRef.send("hello")
eventually(timeout(5 seconds), interval(5 millis)) {
- assert(events === List(("onConnected", remoteAddress)))
+ // anotherEnv is connected in client mode, so the remote address may be unknown depending on
+ // the implementation. Account for that when doing checks.
+ if (remoteAddress != null) {
+ assert(events === List(("onConnected", remoteAddress)))
+ } else {
+ assert(events.size === 1)
+ assert(events(0)._1 === "onConnected")
+ }
}
anotherEnv.shutdown()
anotherEnv.awaitTermination()
eventually(timeout(5 seconds), interval(5 millis)) {
- assert(events === List(
- ("onConnected", remoteAddress),
- ("onNetworkError", remoteAddress),
- ("onDisconnected", remoteAddress)) ||
- events === List(
- ("onConnected", remoteAddress),
- ("onDisconnected", remoteAddress)))
+ // Account for anotherEnv not having an address due to running in client mode.
+ if (remoteAddress != null) {
+ assert(events === List(
+ ("onConnected", remoteAddress),
+ ("onNetworkError", remoteAddress),
+ ("onDisconnected", remoteAddress)) ||
+ events === List(
+ ("onConnected", remoteAddress),
+ ("onDisconnected", remoteAddress)))
+ } else {
+ val eventNames = events.map(_._1)
+ assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") ||
+ eventNames === List("onConnected", "onDisconnected"))
+ }
}
}
@@ -529,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "sendWithReply-unserializable-error")
@@ -558,7 +572,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
conf.set("spark.authenticate.secret", "good")
val localEnv = createRpcEnv(conf, "authentication-local", 13345)
- val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true)
try {
@volatile var message: String = null
@@ -589,7 +603,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
conf.set("spark.authenticate.secret", "good")
val localEnv = createRpcEnv(conf, "authentication-local", 13345)
- val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true)
try {
localEnv.setupEndpoint("ask-authentication", new RpcEndpoint {
diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
index 4aa75c9230..6478ab51c4 100644
--- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
@@ -22,9 +22,12 @@ import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf}
class AkkaRpcEnvSuite extends RpcEnvSuite {
- override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
+ override def createRpcEnv(conf: SparkConf,
+ name: String,
+ port: Int,
+ clientMode: Boolean = false): RpcEnv = {
new AkkaRpcEnvFactory().create(
- RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf)))
+ RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode))
}
test("setupEndpointRef: systemName, address, endpointName") {
@@ -37,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
})
val conf = new SparkConf()
val newRpcEnv = new AkkaRpcEnvFactory().create(
- RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf)))
+ RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false))
try {
val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint")
assert(s"akka.tcp://local@${env.address}/user/test_endpoint" ===
@@ -56,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
val conf = SSLSampleConfigs.sparkSSLConfig()
val securityManager = new SecurityManager(conf)
val rpcEnv = new AkkaRpcEnvFactory().create(
- RpcEnvConfig(conf, "test", "localhost", 12346, securityManager))
+ RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false))
try {
val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint")
assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri)
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
index 973a07a0bd..56743ba650 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
@@ -22,8 +22,13 @@ import org.apache.spark.SparkFunSuite
class NettyRpcAddressSuite extends SparkFunSuite {
test("toString") {
- val addr = RpcEndpointAddress("localhost", 12345, "test")
+ val addr = new RpcEndpointAddress("localhost", 12345, "test")
assert(addr.toString === "spark://test@localhost:12345")
}
+ test("toString for client mode") {
+ val addr = RpcEndpointAddress(null, "test")
+ assert(addr.toString === "spark-client://test")
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
index be19668e17..ce83087ec0 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -22,8 +22,13 @@ import org.apache.spark.rpc._
class NettyRpcEnvSuite extends RpcEnvSuite {
- override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
- val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf))
+ override def createRpcEnv(
+ conf: SparkConf,
+ name: String,
+ port: Int,
+ clientMode: Boolean = false): RpcEnv = {
+ val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf),
+ clientMode)
new NettyRpcEnvFactory().create(config)
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index 5430e4c0c4..f9d8e80c98 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.rpc._
class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
- when(env.deserialize(any(classOf[Array[Byte]]))(any())).
+ when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
test("receive") {
@@ -42,7 +42,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
nettyRpcHandler.receive(client, null, null)
- verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
+ verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
}
test("connectionTerminated") {
@@ -57,9 +57,9 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
nettyRpcHandler.connectionTerminated(client)
- verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
+ verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
verify(dispatcher, times(1)).postToAll(
- RemoteProcessDisconnected(RpcAddress("localhost", 12345)))
+ RemoteProcessDisconnected(RpcAddress("localhost", 40000)))
}
}
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index 541ed9a8d0..e2360eff5c 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -54,6 +54,11 @@
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index c6a6d7ac56..12ae350e4c 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -321,7 +321,8 @@ private[spark] class ApplicationMaster(
private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
val port = sparkConf.getInt("spark.yarn.am.port", 0)
- rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr)
+ rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr,
+ clientMode = true)
val driverRef = waitForSparkDriver()
addAmIpFilter()
registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
@@ -574,9 +575,6 @@ private[spark] class ApplicationMaster(
case x: AddWebUIFilter =>
logInfo(s"Add WebUI Filter. $x")
driver.send(x)
-
- case DriverHello =>
- // SPARK-10987: no action needed for this message.
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {