diff options
Diffstat (limited to 'network/common')
7 files changed, 52 insertions, 15 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 23a8dba593..f0e2004d2d 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -116,7 +116,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { } @Override - public void channelUnregistered() { + public void channelActive() { + } + + @Override + public void channelInactive() { if (numOutstandingRequests() > 0) { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c215bd9d15..c41f5b6873 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -135,9 +135,14 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void connectionTerminated(TransportClient client) { + public void channelActive(TransportClient client) { + delegate.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { try { - delegate.connectionTerminated(client); + delegate.channelInactive(client); } finally { if (saslServer != null) { saslServer.dispose(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index 3843406b27..4a1f28e9ff 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -28,9 +28,12 @@ public abstract class MessageHandler<T extends Message> { /** Handles the receipt of a single message. */ public abstract void handle(T message) throws Exception; + /** Invoked when the channel this MessageHandler is on is active. */ + public abstract void channelActive(); + /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); - /** Invoked when the channel this MessageHandler is on has been unregistered. */ - public abstract void channelUnregistered(); + /** Invoked when the channel this MessageHandler is on is inactive. */ + public abstract void channelInactive(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index ee1c683699..c6ed0f459a 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -69,10 +69,15 @@ public abstract class RpcHandler { } /** - * Invoked when the connection associated with the given client has been invalidated. + * Invoked when the channel associated with the given client is active. + */ + public void channelActive(TransportClient client) { } + + /** + * Invoked when the channel associated with the given client is inactive. * No further requests will come from this client. */ - public void connectionTerminated(TransportClient client) { } + public void channelInactive(TransportClient client) { } public void exceptionCaught(Throwable cause, TransportClient client) { } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 09435bcbab..18a9b7887e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -84,14 +84,29 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + public void channelActive(ChannelHandlerContext ctx) throws Exception { try { - requestHandler.channelUnregistered(); + requestHandler.channelActive(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while registering channel", e); + } + try { + responseHandler.channelActive(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while registering channel", e); + } + super.channelRegistered(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + requestHandler.channelInactive(); } catch (RuntimeException e) { logger.error("Exception from request handler while unregistering channel", e); } try { - responseHandler.channelUnregistered(); + responseHandler.channelInactive(); } catch (RuntimeException e) { logger.error("Exception from response handler while unregistering channel", e); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 105f538831..296ced3db0 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -83,7 +83,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { } @Override - public void channelUnregistered() { + public void channelActive() { + rpcHandler.channelActive(reverseClient); + } + + @Override + public void channelInactive() { if (streamManager != null) { try { streamManager.connectionTerminated(channel); @@ -91,7 +96,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { logger.error("StreamManager connectionTerminated() callback failed.", e); } } - rpcHandler.connectionTerminated(reverseClient); + rpcHandler.channelInactive(reverseClient); } @Override diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 751516b9d8..045773317a 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -160,7 +160,7 @@ public class SparkSaslSuite { long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); while (deadline > System.nanoTime()) { try { - verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class)); error = null; break; } catch (Throwable t) { @@ -362,8 +362,8 @@ public class SparkSaslSuite { saslHandler.getStreamManager(); verify(handler).getStreamManager(); - saslHandler.connectionTerminated(null); - verify(handler).connectionTerminated(any(TransportClient.class)); + saslHandler.channelInactive(null); + verify(handler).channelInactive(any(TransportClient.class)); saslHandler.exceptionCaught(null, null); verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); |