diff options
Diffstat (limited to 'core/src')
4 files changed, 96 insertions, 77 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 8ffcfc0878..4172d924c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -65,7 +65,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo /** * On connection termination, clean up shuffle files written by the associated application. */ - override def connectionTerminated(client: TransportClient): Unit = { + override def channelInactive(client: TransportClient): Unit = { val address = client.getSocketAddress if (connectedApps.contains(address)) { val appId = connectedApps(address) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 975ea1a1ab..090a1b9f6e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -548,10 +548,6 @@ private[netty] class NettyRpcHandler( nettyEnv: NettyRpcEnv, streamManager: StreamManager) extends RpcHandler with Logging { - // TODO: Can we add connection callback (channel registered) to the underlying framework? - // A variable to track whether we should dispatch the RemoteProcessConnected message. - private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() - // A variable to track the remote RpcEnv addresses of all clients private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() @@ -574,9 +570,6 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - if (clients.putIfAbsent(client, JBoolean.TRUE) == null) { - dispatcher.postToAll(RemoteProcessConnected(clientAddr)) - } val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. @@ -613,10 +606,16 @@ private[netty] class NettyRpcHandler( } } - override def connectionTerminated(client: TransportClient): Unit = { + override def channelActive(client: TransportClient): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + assert(addr != null) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + dispatcher.postToAll(RemoteProcessConnected(clientAddr)) + } + + override def channelInactive(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - clients.remove(client) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 49e3e0191c..7b3a17c172 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -484,10 +484,16 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - test("network events") { + /** + * Setup an [[RpcEndpoint]] to collect all network events. + * @return the [[RpcEndpointRef]] and an `Seq` that contains network events. + */ + private def setupNetworkEndpoint( + _env: RpcEnv, + name: String): (RpcEndpointRef, Seq[(Any, Any)]) = { val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { + override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => @@ -507,83 +513,97 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) + (ref, events) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - // anotherEnv is connected in client mode, so the remote address may be unknown depending on - // the implementation. Account for that when doing checks. - if (remoteAddress != null) { - assert(events === List(("onConnected", remoteAddress))) - } else { - assert(events.size === 1) - assert(events(0)._1 === "onConnected") + test("network events in sever RpcEnv when another RpcEnv is in server mode") { + val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false) + val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") + val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") + try { + val serverRefInServer2 = + serverEnv1.setupEndpointRef("server2", serverRef2.address, serverRef2.name) + // Send a message to set up the connection + serverRefInServer2.send("hello") + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) } - } - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - // Account for anotherEnv not having an address due to running in client mode. - if (remoteAddress != null) { - assert(events === List( - ("onConnected", remoteAddress), - ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress)) || - events === List( - ("onConnected", remoteAddress), - ("onDisconnected", remoteAddress))) - } else { - val eventNames = events.map(_._1) - assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") || - eventNames === List("onConnected", "onDisconnected")) + serverEnv2.shutdown() + serverEnv2.awaitTermination() + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) + assert(events.contains(("onDisconnected", serverEnv2.address))) } + } finally { + serverEnv1.shutdown() + serverEnv2.shutdown() + serverEnv1.awaitTermination() + serverEnv2.awaitTermination() } } - test("network events between non-client-mode RpcEnvs") { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + test("network events in sever RpcEnv when another RpcEnv is in client mode") { + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + try { + val serverRefInClient = + clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") - override def receive: PartialFunction[Any, Unit] = { - case "hello" => - case m => events += "receive" -> m + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.map(_._1).contains("onConnected")) } - override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress - } + clientEnv.shutdown() + clientEnv.awaitTermination() - override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.map(_._1).contains("onConnected")) + assert(events.map(_._1).contains("onDisconnected")) } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() + } + } - override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress - } + test("network events in client RpcEnv when another RpcEnv is in server mode") { + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") + val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") + try { + val serverRefInClient = + clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") - }) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events-non-client") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - assert(events.contains(("onConnected", remoteAddress))) - } + serverEnv.shutdown() + serverEnv.awaitTermination() - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - assert(events.contains(("onConnected", remoteAddress))) - assert(events.contains(("onDisconnected", remoteAddress))) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + assert(events.contains(("onDisconnected", serverEnv.address))) + } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index ebd6f70071..d4aebe9fd9 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -43,7 +43,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) } @@ -55,10 +55,10 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.connectionTerminated(client) + nettyRpcHandler.channelInactive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) verify(dispatcher, times(1)).postToAll( |