diff options
Diffstat (limited to 'network/common')
3 files changed, 37 insertions, 3 deletions
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 3f2ebe3288..7033adb9ca 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -115,9 +115,18 @@ class SaslRpcHandler extends RpcHandler { @Override public void connectionTerminated(TransportClient client) { - if (saslServer != null) { - saslServer.dispose(); + try { + delegate.connectionTerminated(client); + } finally { + if (saslServer != null) { + saslServer.dispose(); + } } } + @Override + public void exceptionCaught(Throwable cause, TransportClient client) { + delegate.exceptionCaught(cause, client); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 96941d26be..9b8b047b49 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -76,7 +76,13 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { @Override public void channelUnregistered() { - streamManager.connectionTerminated(channel); + if (streamManager != null) { + try { + streamManager.connectionTerminated(channel); + } catch (RuntimeException e) { + logger.error("StreamManager connectionTerminated() callback failed.", e); + } + } rpcHandler.connectionTerminated(reverseClient); } diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 8104004847..3469e84e7f 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -153,6 +153,8 @@ public class SparkSaslSuite { assertEquals("Pong", new String(response, StandardCharsets.UTF_8)); } finally { ctx.close(); + // There should be 2 terminated events; one for the client, one for the server. + verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); } } @@ -334,6 +336,23 @@ public class SparkSaslSuite { } } + @Test + public void testRpcHandlerDelegate() throws Exception { + // Tests all delegates exception for receive(), which is more complicated and already handled + // by all other tests. + RpcHandler handler = mock(RpcHandler.class); + RpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null); + + saslHandler.getStreamManager(); + verify(handler).getStreamManager(); + + saslHandler.connectionTerminated(null); + verify(handler).connectionTerminated(any(TransportClient.class)); + + saslHandler.exceptionCaught(null, null); + verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); + } + private static class SaslTestCtx { final TransportClient client; |