 core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala        |   2
 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala                           |  17
 core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala                                 | 148
 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala                  |   6
 network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java |   6
 network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java             |   9
 network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java           |   7
 network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java               |   9
 network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java  |  21
 network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java  |   9
 network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java             |   6
 11 files changed, 148 insertions(+), 92 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
index 8ffcfc0878..4172d924c8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
@@ -65,7 +65,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo
/**
* On connection termination, clean up shuffle files written by the associated application.
*/
- override def connectionTerminated(client: TransportClient): Unit = {
+ override def channelInactive(client: TransportClient): Unit = {
val address = client.getSocketAddress
if (connectedApps.contains(address)) {
val appId = connectedApps(address)
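
The Mesos handler above is the canonical consumer of the renamed callback: per-connection state keyed by the client's socket address, created when an application registers and torn down when the channel goes inactive. Below is a minimal sketch of that pattern, assuming this era's RpcHandler.receive(TransportClient, Array[Byte], RpcResponseCallback) signature; the handler class and its cleanup hook are hypothetical, not the actual Mesos service.

```scala
import java.net.SocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}

class AppTrackingHandler extends RpcHandler {
  // One entry per live connection, keyed the same way the Mesos handler keys it.
  private val connectedApps = new ConcurrentHashMap[SocketAddress, String]()
  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      message: Array[Byte],
      callback: RpcResponseCallback): Unit = {
    // A real handler decodes a registration message; treat the payload as an app id here.
    connectedApps.put(client.getSocketAddress, new String(message, StandardCharsets.UTF_8))
    callback.onSuccess(new Array[Byte](0))
  }

  override def getStreamManager: StreamManager = streamManager

  // No further requests can arrive once the channel is inactive, so the
  // per-connection state can be dropped safely.
  override def channelInactive(client: TransportClient): Unit = {
    val appId = connectedApps.remove(client.getSocketAddress)
    if (appId != null) {
      cleanupShuffleFiles(appId) // hypothetical cleanup hook
    }
  }

  private def cleanupShuffleFiles(appId: String): Unit = {
    // e.g. delete the application's shuffle directories
  }
}
```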
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 975ea1a1ab..090a1b9f6e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -548,10 +548,6 @@ private[netty] class NettyRpcHandler(
nettyEnv: NettyRpcEnv,
streamManager: StreamManager) extends RpcHandler with Logging {
- // TODO: Can we add connection callback (channel registered) to the underlying framework?
- // A variable to track whether we should dispatch the RemoteProcessConnected message.
- private val clients = new ConcurrentHashMap[TransportClient, JBoolean]()
-
// A variable to track the remote RpcEnv addresses of all clients
private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()
@@ -574,9 +570,6 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- if (clients.putIfAbsent(client, JBoolean.TRUE) == null) {
- dispatcher.postToAll(RemoteProcessConnected(clientAddr))
- }
val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
if (requestMessage.senderAddress == null) {
// Create a new message with the socket address of the client as the sender.
@@ -613,10 +606,16 @@ private[netty] class NettyRpcHandler(
}
}
- override def connectionTerminated(client: TransportClient): Unit = {
+ override def channelActive(client: TransportClient): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ assert(addr != null)
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ dispatcher.postToAll(RemoteProcessConnected(clientAddr))
+ }
+
+ override def channelInactive(client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
- clients.remove(client)
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
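
Before this change, NettyRpcHandler could only learn about a peer when its first message arrived, so it kept the `clients` map purely to post RemoteProcessConnected once. With `channelActive` the transport layer reports each connection exactly once on its own, and the map disappears. From an endpoint's point of view the resulting lifecycle looks like the sketch below; it is written in the style of the test suite (RpcEndpoint is a package-private Spark trait, so this is illustrative rather than user-facing API).

```scala
import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv}

// Sketch of an endpoint observing the lifecycle messages NettyRpcHandler posts.
class PeerTracker(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receive: PartialFunction[Any, Unit] = {
    case msg => // ordinary application messages
  }

  override def onConnected(remoteAddress: RpcAddress): Unit = {
    // Driven by RemoteProcessConnected: fires as soon as the peer's channel
    // is active, not when its first message happens to arrive.
  }

  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
    // Driven by RemoteProcessDisconnected, after channelInactive.
  }
}
```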
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 49e3e0191c..7b3a17c172 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -484,10 +484,16 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
- test("network events") {
+ /**
+ * Set up an [[RpcEndpoint]] to collect all network events.
+ * @return the [[RpcEndpointRef]] and a `Seq` that contains network events.
+ */
+ private def setupNetworkEndpoint(
+ _env: RpcEnv,
+ name: String): (RpcEndpointRef, Seq[(Any, Any)]) = {
val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
- env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint {
- override val rpcEnv = env
+ val ref = _env.setupEndpoint(name, new ThreadSafeRpcEndpoint {
+ override val rpcEnv = _env
override def receive: PartialFunction[Any, Unit] = {
case "hello" =>
@@ -507,83 +513,97 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
+ (ref, events)
+ }
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true)
- // Use anotherEnv to find out the RpcEndpointRef
- val rpcEndpointRef = anotherEnv.setupEndpointRef(
- "local", env.address, "network-events")
- val remoteAddress = anotherEnv.address
- rpcEndpointRef.send("hello")
- eventually(timeout(5 seconds), interval(5 millis)) {
- // anotherEnv is connected in client mode, so the remote address may be unknown depending on
- // the implementation. Account for that when doing checks.
- if (remoteAddress != null) {
- assert(events === List(("onConnected", remoteAddress)))
- } else {
- assert(events.size === 1)
- assert(events(0)._1 === "onConnected")
+ test("network events in sever RpcEnv when another RpcEnv is in server mode") {
+ val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false)
+ val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false)
+ val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events")
+ val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events")
+ try {
+ val serverRefInServer2 =
+ serverEnv1.setupEndpointRef("server2", serverRef2.address, serverRef2.name)
+ // Send a message to set up the connection
+ serverRefInServer2.send("hello")
+
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events.contains(("onConnected", serverEnv2.address)))
}
- }
- anotherEnv.shutdown()
- anotherEnv.awaitTermination()
- eventually(timeout(5 seconds), interval(5 millis)) {
- // Account for anotherEnv not having an address due to running in client mode.
- if (remoteAddress != null) {
- assert(events === List(
- ("onConnected", remoteAddress),
- ("onNetworkError", remoteAddress),
- ("onDisconnected", remoteAddress)) ||
- events === List(
- ("onConnected", remoteAddress),
- ("onDisconnected", remoteAddress)))
- } else {
- val eventNames = events.map(_._1)
- assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") ||
- eventNames === List("onConnected", "onDisconnected"))
+ serverEnv2.shutdown()
+ serverEnv2.awaitTermination()
+
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events.contains(("onConnected", serverEnv2.address)))
+ assert(events.contains(("onDisconnected", serverEnv2.address)))
}
+ } finally {
+ serverEnv1.shutdown()
+ serverEnv2.shutdown()
+ serverEnv1.awaitTermination()
+ serverEnv2.awaitTermination()
}
}
- test("network events between non-client-mode RpcEnvs") {
- val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
- env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint {
- override val rpcEnv = env
+ test("network events in sever RpcEnv when another RpcEnv is in client mode") {
+ val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false)
+ val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events")
+ val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true)
+ try {
+ val serverRefInClient =
+ clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name)
+ // Send a message to set up the connection
+ serverRefInClient.send("hello")
- override def receive: PartialFunction[Any, Unit] = {
- case "hello" =>
- case m => events += "receive" -> m
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ // We don't know the exact client address, but we can at least verify the event type
+ assert(events.map(_._1).contains("onConnected"))
}
- override def onConnected(remoteAddress: RpcAddress): Unit = {
- events += "onConnected" -> remoteAddress
- }
+ clientEnv.shutdown()
+ clientEnv.awaitTermination()
- override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- events += "onDisconnected" -> remoteAddress
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ // We don't know the exact client address, but we can at least verify the event type
+ assert(events.map(_._1).contains("onConnected"))
+ assert(events.map(_._1).contains("onDisconnected"))
}
+ } finally {
+ clientEnv.shutdown()
+ serverEnv.shutdown()
+ clientEnv.awaitTermination()
+ serverEnv.awaitTermination()
+ }
+ }
- override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
- events += "onNetworkError" -> remoteAddress
- }
+ test("network events in client RpcEnv when another RpcEnv is in server mode") {
+ val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true)
+ val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false)
+ val (_, events) = setupNetworkEndpoint(clientEnv, "network-events")
+ val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events")
+ try {
+ val serverRefInClient =
+ clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name)
+ // Send a message to set up the connection
+ serverRefInClient.send("hello")
- })
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events.contains(("onConnected", serverEnv.address)))
+ }
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false)
- // Use anotherEnv to find out the RpcEndpointRef
- val rpcEndpointRef = anotherEnv.setupEndpointRef(
- "local", env.address, "network-events-non-client")
- val remoteAddress = anotherEnv.address
- rpcEndpointRef.send("hello")
- eventually(timeout(5 seconds), interval(5 millis)) {
- assert(events.contains(("onConnected", remoteAddress)))
- }
+ serverEnv.shutdown()
+ serverEnv.awaitTermination()
- anotherEnv.shutdown()
- anotherEnv.awaitTermination()
- eventually(timeout(5 seconds), interval(5 millis)) {
- assert(events.contains(("onConnected", remoteAddress)))
- assert(events.contains(("onDisconnected", remoteAddress)))
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events.contains(("onConnected", serverEnv.address)))
+ assert(events.contains(("onDisconnected", serverEnv.address)))
+ }
+ } finally {
+ clientEnv.shutdown()
+ serverEnv.shutdown()
+ clientEnv.awaitTermination()
+ serverEnv.awaitTermination()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index ebd6f70071..d4aebe9fd9 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -43,7 +43,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.receive(client, null, null)
+ nettyRpcHandler.channelActive(client)
verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
}
@@ -55,10 +55,10 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.receive(client, null, null)
+ nettyRpcHandler.channelActive(client)
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.connectionTerminated(client)
+ nettyRpcHandler.channelInactive(client)
verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
verify(dispatcher, times(1)).postToAll(
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 23a8dba593..f0e2004d2d 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -116,7 +116,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
}
@Override
- public void channelUnregistered() {
+ public void channelActive() {
+ }
+
+ @Override
+ public void channelInactive() {
if (numOutstandingRequests() > 0) {
String remoteAddress = NettyUtils.getRemoteAddress(channel);
logger.error("Still have {} requests outstanding when connection from {} is closed",
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index c215bd9d15..c41f5b6873 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -135,9 +135,14 @@ class SaslRpcHandler extends RpcHandler {
}
@Override
- public void connectionTerminated(TransportClient client) {
+ public void channelActive(TransportClient client) {
+ delegate.channelActive(client);
+ }
+
+ @Override
+ public void channelInactive(TransportClient client) {
try {
- delegate.connectionTerminated(client);
+ delegate.channelInactive(client);
} finally {
if (saslServer != null) {
saslServer.dispose();
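
SaslRpcHandler is a decorator, so the new callback pair has to be forwarded explicitly: `channelActive` passes straight through, while `channelInactive` forwards first and then disposes the SASL server state. The forwarding obligation generalizes to any wrapper; here is a hypothetical sketch (again assuming the era's Array[Byte] receive signature).

```scala
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{RpcHandler, StreamManager}

// Hypothetical decorator: every callback, including the lifecycle pair added
// in this change, must be forwarded or the wrapped handler silently misses
// connection events.
class ForwardingRpcHandler(delegate: RpcHandler) extends RpcHandler {
  override def receive(
      client: TransportClient,
      message: Array[Byte],
      callback: RpcResponseCallback): Unit =
    delegate.receive(client, message, callback)

  override def getStreamManager: StreamManager = delegate.getStreamManager

  override def channelActive(client: TransportClient): Unit =
    delegate.channelActive(client)

  override def channelInactive(client: TransportClient): Unit =
    delegate.channelInactive(client)

  override def exceptionCaught(cause: Throwable, client: TransportClient): Unit =
    delegate.exceptionCaught(cause, client)
}
```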
diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
index 3843406b27..4a1f28e9ff 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -28,9 +28,12 @@ public abstract class MessageHandler<T extends Message> {
/** Handles the receipt of a single message. */
public abstract void handle(T message) throws Exception;
+ /** Invoked when the channel this MessageHandler is on is active. */
+ public abstract void channelActive();
+
/** Invoked when an exception was caught on the Channel. */
public abstract void exceptionCaught(Throwable cause);
- /** Invoked when the channel this MessageHandler is on has been unregistered. */
- public abstract void channelUnregistered();
+ /** Invoked when the channel this MessageHandler is on is inactive. */
+ public abstract void channelInactive();
}
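
A concrete MessageHandler now implements one data callback and three channel-lifecycle callbacks, mirroring Netty's inbound events. A minimal no-op sketch of the required surface (hypothetical class, for illustration only):

```scala
import org.apache.spark.network.protocol.ResponseMessage
import org.apache.spark.network.server.MessageHandler

// Smallest possible concrete MessageHandler after this change.
class NoOpResponseHandler extends MessageHandler[ResponseMessage] {
  override def handle(message: ResponseMessage): Unit = ()  // one message received
  override def channelActive(): Unit = ()                   // connection established
  override def exceptionCaught(cause: Throwable): Unit = () // error on the channel
  override def channelInactive(): Unit = ()                 // connection closed
}
```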
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index ee1c683699..c6ed0f459a 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -69,10 +69,15 @@ public abstract class RpcHandler {
}
/**
- * Invoked when the connection associated with the given client has been invalidated.
+ * Invoked when the channel associated with the given client is active.
+ */
+ public void channelActive(TransportClient client) { }
+
+ /**
+ * Invoked when the channel associated with the given client is inactive.
* No further requests will come from this client.
*/
- public void connectionTerminated(TransportClient client) { }
+ public void channelInactive(TransportClient client) { }
public void exceptionCaught(Throwable cause, TransportClient client) { }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index 09435bcbab..18a9b7887e 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -84,14 +84,29 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
}
@Override
- public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+ public void channelActive(ChannelHandlerContext ctx) throws Exception {
try {
- requestHandler.channelUnregistered();
+ requestHandler.channelActive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while registering channel", e);
+ }
+ try {
+ responseHandler.channelActive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while registering channel", e);
+ }
+ super.channelActive(ctx);
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelInactive();
} catch (RuntimeException e) {
logger.error("Exception from request handler while unregistering channel", e);
}
try {
- responseHandler.channelUnregistered();
+ responseHandler.channelInactive();
} catch (RuntimeException e) {
logger.error("Exception from response handler while unregistering channel", e);
}
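
`channelActive` and `channelInactive` are Netty's standard connection-level events, fired on connect and close; the old `channelRegistered`/`channelUnregistered` hooks instead bracket the channel's association with an event loop, so they do not line up exactly with the connection's lifetime. An overriding handler must also keep propagating the event down the pipeline, which is what the trailing `super` call does. A standalone Netty sketch of the same pattern (hypothetical handler, plain Netty API):

```scala
import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}

// Minimal Netty handler hooking the two lifecycle events Spark now uses.
class LifecycleLogger extends ChannelInboundHandlerAdapter {
  override def channelActive(ctx: ChannelHandlerContext): Unit = {
    println(s"active: ${ctx.channel().remoteAddress()}")
    ctx.fireChannelActive() // keep the event moving down the pipeline
  }

  override def channelInactive(ctx: ChannelHandlerContext): Unit = {
    println(s"inactive: ${ctx.channel().remoteAddress()}")
    ctx.fireChannelInactive()
  }
}
```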
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 105f538831..296ced3db0 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -83,7 +83,12 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
}
@Override
- public void channelUnregistered() {
+ public void channelActive() {
+ rpcHandler.channelActive(reverseClient);
+ }
+
+ @Override
+ public void channelInactive() {
if (streamManager != null) {
try {
streamManager.connectionTerminated(channel);
@@ -91,7 +96,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
logger.error("StreamManager connectionTerminated() callback failed.", e);
}
}
- rpcHandler.connectionTerminated(reverseClient);
+ rpcHandler.channelInactive(reverseClient);
}
@Override
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 751516b9d8..045773317a 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -160,7 +160,7 @@ public class SparkSaslSuite {
long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
while (deadline > System.nanoTime()) {
try {
- verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class));
+ verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class));
error = null;
break;
} catch (Throwable t) {
@@ -362,8 +362,8 @@ public class SparkSaslSuite {
saslHandler.getStreamManager();
verify(handler).getStreamManager();
- saslHandler.connectionTerminated(null);
- verify(handler).connectionTerminated(any(TransportClient.class));
+ saslHandler.channelInactive(null);
+ verify(handler).channelInactive(any(TransportClient.class));
saslHandler.exceptionCaught(null, null);
verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class));
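
The verification pattern above — stub the delegate, invoke the renamed callback, verify the forwarding — extends naturally to the new `channelActive`. A hedged sketch using Mockito as the suite does, and reusing the hypothetical ForwardingRpcHandler from the SaslRpcHandler note earlier:

```scala
import org.mockito.Mockito.{mock, verify}

import org.apache.spark.network.client.TransportClient
import org.apache.spark.network.server.RpcHandler

object ForwardingHandlerCheck extends App {
  // Stub the collaborators, drive the wrapper, and verify the delegate saw
  // both lifecycle callbacks.
  val delegate = mock(classOf[RpcHandler])
  val wrapper = new ForwardingRpcHandler(delegate) // hypothetical decorator sketched earlier
  val client = mock(classOf[TransportClient])

  wrapper.channelActive(client)
  wrapper.channelInactive(client)

  verify(delegate).channelActive(client)
  verify(delegate).channelInactive(client)
}
```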