diff options
11 files changed, 148 insertions, 92 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 8ffcfc0878..4172d924c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -65,7 +65,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo /** * On connection termination, clean up shuffle files written by the associated application. */ - override def connectionTerminated(client: TransportClient): Unit = { + override def channelInactive(client: TransportClient): Unit = { val address = client.getSocketAddress if (connectedApps.contains(address)) { val appId = connectedApps(address) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 975ea1a1ab..090a1b9f6e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -548,10 +548,6 @@ private[netty] class NettyRpcHandler( nettyEnv: NettyRpcEnv, streamManager: StreamManager) extends RpcHandler with Logging { - // TODO: Can we add connection callback (channel registered) to the underlying framework? - // A variable to track whether we should dispatch the RemoteProcessConnected message. 
- private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() - // A variable to track the remote RpcEnv addresses of all clients private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() @@ -574,9 +570,6 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - if (clients.putIfAbsent(client, JBoolean.TRUE) == null) { - dispatcher.postToAll(RemoteProcessConnected(clientAddr)) - } val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. @@ -613,10 +606,16 @@ private[netty] class NettyRpcHandler( } } - override def connectionTerminated(client: TransportClient): Unit = { + override def channelActive(client: TransportClient): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + assert(addr != null) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + dispatcher.postToAll(RemoteProcessConnected(clientAddr)) + } + + override def channelInactive(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - clients.remove(client) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 49e3e0191c..7b3a17c172 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -484,10 +484,16 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - test("network events") { + /** + * Setup an [[RpcEndpoint]] to 
collect all network events. + * @return the [[RpcEndpointRef]] and a `Seq` that contains network events. + */ + private def setupNetworkEndpoint( + _env: RpcEnv, + name: String): (RpcEndpointRef, Seq[(Any, Any)]) = { val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + val ref = _env.setupEndpoint(name, new ThreadSafeRpcEndpoint { + override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => @@ -507,83 +513,97 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) + (ref, events) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - // anotherEnv is connected in client mode, so the remote address may be unknown depending on - // the implementation. Account for that when doing checks. 
- if (remoteAddress != null) { - assert(events === List(("onConnected", remoteAddress))) - } else { - assert(events.size === 1) - assert(events(0)._1 === "onConnected") + test("network events in server RpcEnv when another RpcEnv is in server mode") { + val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false) + val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") + val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") + try { + val serverRefInServer2 = + serverEnv1.setupEndpointRef("server2", serverRef2.address, serverRef2.name) + // Send a message to set up the connection + serverRefInServer2.send("hello") + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) } - } - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - // Account for anotherEnv not having an address due to running in client mode. 
- if (remoteAddress != null) { - assert(events === List( - ("onConnected", remoteAddress), - ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress)) || - events === List( - ("onConnected", remoteAddress), - ("onDisconnected", remoteAddress))) - } else { - val eventNames = events.map(_._1) - assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") || - eventNames === List("onConnected", "onDisconnected")) + serverEnv2.shutdown() + serverEnv2.awaitTermination() + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) + assert(events.contains(("onDisconnected", serverEnv2.address))) } + } finally { + serverEnv1.shutdown() + serverEnv2.shutdown() + serverEnv1.awaitTermination() + serverEnv2.awaitTermination() } } - test("network events between non-client-mode RpcEnvs") { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + test("network events in server RpcEnv when another RpcEnv is in client mode") { + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + try { + val serverRefInClient = + clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") - override def receive: PartialFunction[Any, Unit] = { - case "hello" => - case m => events += "receive" -> m + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.map(_._1).contains("onConnected")) } - override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress - 
} + clientEnv.shutdown() + clientEnv.awaitTermination() - override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.map(_._1).contains("onConnected")) + assert(events.map(_._1).contains("onDisconnected")) } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() + } + } - override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress - } + test("network events in client RpcEnv when another RpcEnv is in server mode") { + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") + val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") + try { + val serverRefInClient = + clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") - }) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events-non-client") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - assert(events.contains(("onConnected", remoteAddress))) - } + serverEnv.shutdown() + serverEnv.awaitTermination() - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - 
assert(events.contains(("onConnected", remoteAddress))) - assert(events.contains(("onDisconnected", remoteAddress))) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + assert(events.contains(("onDisconnected", serverEnv.address))) + } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index ebd6f70071..d4aebe9fd9 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -43,7 +43,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) } @@ -55,10 +55,10 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.connectionTerminated(client) + nettyRpcHandler.channelInactive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) verify(dispatcher, times(1)).postToAll( diff --git 
a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 23a8dba593..f0e2004d2d 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -116,7 +116,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { } @Override - public void channelUnregistered() { + public void channelActive() { + } + + @Override + public void channelInactive() { if (numOutstandingRequests() > 0) { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c215bd9d15..c41f5b6873 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -135,9 +135,14 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void connectionTerminated(TransportClient client) { + public void channelActive(TransportClient client) { + delegate.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { try { - delegate.connectionTerminated(client); + delegate.channelInactive(client); } finally { if (saslServer != null) { saslServer.dispose(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index 3843406b27..4a1f28e9ff 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ 
b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -28,9 +28,12 @@ public abstract class MessageHandler<T extends Message> { /** Handles the receipt of a single message. */ public abstract void handle(T message) throws Exception; + /** Invoked when the channel this MessageHandler is on is active. */ + public abstract void channelActive(); + /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); - /** Invoked when the channel this MessageHandler is on has been unregistered. */ - public abstract void channelUnregistered(); + /** Invoked when the channel this MessageHandler is on is inactive. */ + public abstract void channelInactive(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index ee1c683699..c6ed0f459a 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -69,10 +69,15 @@ public abstract class RpcHandler { } /** - * Invoked when the connection associated with the given client has been invalidated. + * Invoked when the channel associated with the given client is active. + */ + public void channelActive(TransportClient client) { } + + /** + * Invoked when the channel associated with the given client is inactive. * No further requests will come from this client. 
*/ - public void connectionTerminated(TransportClient client) { } + public void channelInactive(TransportClient client) { } public void exceptionCaught(Throwable cause, TransportClient client) { } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 09435bcbab..18a9b7887e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -84,14 +84,29 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + public void channelActive(ChannelHandlerContext ctx) throws Exception { try { - requestHandler.channelUnregistered(); + requestHandler.channelActive(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while registering channel", e); + } + try { + responseHandler.channelActive(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while registering channel", e); + } + super.channelActive(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + requestHandler.channelInactive(); } catch (RuntimeException e) { logger.error("Exception from request handler while unregistering channel", e); } try { - responseHandler.channelUnregistered(); + responseHandler.channelInactive(); } catch (RuntimeException e) { logger.error("Exception from response handler while unregistering channel", e); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 105f538831..296ced3db0 100644 --- 
a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -83,7 +83,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { } @Override - public void channelUnregistered() { + public void channelActive() { + rpcHandler.channelActive(reverseClient); + } + + @Override + public void channelInactive() { if (streamManager != null) { try { streamManager.connectionTerminated(channel); @@ -91,7 +96,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { logger.error("StreamManager connectionTerminated() callback failed.", e); } } - rpcHandler.connectionTerminated(reverseClient); + rpcHandler.channelInactive(reverseClient); } @Override diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 751516b9d8..045773317a 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -160,7 +160,7 @@ public class SparkSaslSuite { long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); while (deadline > System.nanoTime()) { try { - verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class)); error = null; break; } catch (Throwable t) { @@ -362,8 +362,8 @@ public class SparkSaslSuite { saslHandler.getStreamManager(); verify(handler).getStreamManager(); - saslHandler.connectionTerminated(null); - verify(handler).connectionTerminated(any(TransportClient.class)); + saslHandler.channelInactive(null); + verify(handler).channelInactive(any(TransportClient.class)); saslHandler.exceptionCaught(null, null); 
verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); |