author     Shixiong Zhu <shixiong@databricks.com>  2015-12-18 16:06:37 -0800
committer  Andrew Or <andrew@databricks.com>       2015-12-18 16:06:37 -0800
commit     007a32f90af1065bfa3ca4cdb194c40c06e87abf (patch)
tree       b530e66b60864804fe4baaf7360b9d3b812122cb /core/src
parent     0514e8d4b69615ba8918649e7e3c46b5713b6540 (diff)
[SPARK-11097][CORE] Add channelActive callback to RpcHandler to monitor the new connections
Added `channelActive` to `RpcHandler` so that `NettyRpcHandler` doesn't need `clients` any more.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #10301 from zsxwing/network-events.
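As an orientation aid (not part of the commit), here is a minimal, self-contained Scala sketch of the pattern this change introduces: connection tracking moves out of receive() and into dedicated lifecycle callbacks. Client and Dispatcher below are hypothetical stand-ins for Spark's TransportClient and Dispatcher; only the callback shape mirrors the diff.

    import java.net.InetSocketAddress

    // Hypothetical stand-ins; only the callback shape mirrors the real classes.
    final case class RpcAddress(host: String, port: Int)
    final case class RemoteProcessConnected(addr: RpcAddress)
    final case class RemoteProcessDisconnected(addr: RpcAddress)

    trait Client { def remoteAddress: InetSocketAddress }
    trait Dispatcher { def postToAll(event: Any): Unit }

    class LifecycleAwareHandler(dispatcher: Dispatcher) {
      // Invoked by the transport layer as soon as the channel becomes active,
      // so the connected event no longer waits for the first RPC message.
      def channelActive(client: Client): Unit = {
        val addr = client.remoteAddress
        dispatcher.postToAll(RemoteProcessConnected(RpcAddress(addr.getHostName, addr.getPort)))
      }

      // Invoked when the channel goes inactive; per-connection state is cleaned up here.
      def channelInactive(client: Client): Unit = {
        val addr = client.remoteAddress
        if (addr != null) {
          dispatcher.postToAll(RemoteProcessDisconnected(RpcAddress(addr.getHostName, addr.getPort)))
        }
      }
    }

Reporting the connection from channelActive means the connected event fires as soon as the channel is established, rather than only after the first message arrives.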
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala                    |  17
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala                          | 148
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala           |   6
4 files changed, 96 insertions, 77 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
index 8ffcfc0878..4172d924c8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
@@ -65,7 +65,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo
/**
* On connection termination, clean up shuffle files written by the associated application.
*/
- override def connectionTerminated(client: TransportClient): Unit = {
+ override def channelInactive(client: TransportClient): Unit = {
val address = client.getSocketAddress
if (connectedApps.contains(address)) {
val appId = connectedApps(address)
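A hedged sketch of the cleanup idiom in the hunk above, under assumed stand-in names (AppCleanupTracker, register, and onChannelInactive are illustrative, not Spark API): applications register by socket address, and channel deactivation removes the entry and releases the application's resources.

    import java.net.SocketAddress
    import java.util.concurrent.ConcurrentHashMap

    class AppCleanupTracker {
      private val connectedApps = new ConcurrentHashMap[SocketAddress, String]()

      def register(address: SocketAddress, appId: String): Unit =
        connectedApps.put(address, appId)

      // Mirrors the channelInactive override: look up the app by address, then clean up.
      def onChannelInactive(address: SocketAddress): Unit = {
        val appId = connectedApps.remove(address)
        if (appId != null) {
          // Stand-in for deleting the application's shuffle files.
          println(s"Cleaning up shuffle files for application $appId")
        }
      }
    }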
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 975ea1a1ab..090a1b9f6e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -548,10 +548,6 @@ private[netty] class NettyRpcHandler(
nettyEnv: NettyRpcEnv,
streamManager: StreamManager) extends RpcHandler with Logging {
- // TODO: Can we add connection callback (channel registered) to the underlying framework?
- // A variable to track whether we should dispatch the RemoteProcessConnected message.
- private val clients = new ConcurrentHashMap[TransportClient, JBoolean]()
-
// A variable to track the remote RpcEnv addresses of all clients
private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()
@@ -574,9 +570,6 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- if (clients.putIfAbsent(client, JBoolean.TRUE) == null) {
- dispatcher.postToAll(RemoteProcessConnected(clientAddr))
- }
val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
if (requestMessage.senderAddress == null) {
// Create a new message with the socket address of the client as the sender.
@@ -613,10 +606,16 @@ private[netty] class NettyRpcHandler(
}
}
- override def connectionTerminated(client: TransportClient): Unit = {
+ override def channelActive(client: TransportClient): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ assert(addr != null)
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ dispatcher.postToAll(RemoteProcessConnected(clientAddr))
+ }
+
+ override def channelInactive(client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
- clients.remove(client)
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
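For context on what the deleted clients map was doing: ConcurrentHashMap.putIfAbsent returns null only for the first caller, so only a client's first message fired RemoteProcessConnected. A minimal sketch of that dedup idiom, with a generic key type standing in for TransportClient (FirstMessageDedup and onFirstMessage are illustrative names):

    import java.lang.{Boolean => JBoolean}
    import java.util.concurrent.ConcurrentHashMap

    object FirstMessageDedup {
      // putIfAbsent returns null only when the key was absent, so `fire` runs
      // at most once per client; channelActive makes this bookkeeping unnecessary.
      private val clients = new ConcurrentHashMap[AnyRef, JBoolean]()

      def onFirstMessage(client: AnyRef)(fire: => Unit): Unit = {
        if (clients.putIfAbsent(client, JBoolean.TRUE) == null) {
          fire
        }
      }
    }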
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 49e3e0191c..7b3a17c172 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -484,10 +484,16 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
- test("network events") {
+ /**
+ * Set up an [[RpcEndpoint]] to collect all network events.
+ * @return the [[RpcEndpointRef]] and a `Seq` that contains network events.
+ */
+ private def setupNetworkEndpoint(
+ _env: RpcEnv,
+ name: String): (RpcEndpointRef, Seq[(Any, Any)]) = {
val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
- env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint {
- override val rpcEnv = env
+ val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint {
+ override val rpcEnv = _env
override def receive: PartialFunction[Any, Unit] = {
case "hello" =>
@@ -507,83 +513,97 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
+ (ref, events)
+ }
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true)
- // Use anotherEnv to find out the RpcEndpointRef
- val rpcEndpointRef = anotherEnv.setupEndpointRef(
- "local", env.address, "network-events")
- val remoteAddress = anotherEnv.address
- rpcEndpointRef.send("hello")
- eventually(timeout(5 seconds), interval(5 millis)) {
- // anotherEnv is connected in client mode, so the remote address may be unknown depending on
- // the implementation. Account for that when doing checks.
- if (remoteAddress != null) {
- assert(events === List(("onConnected", remoteAddress)))
- } else {
- assert(events.size === 1)
- assert(events(0)._1 === "onConnected")
+ test("network events in sever RpcEnv when another RpcEnv is in server mode") {
+ val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false)
+ val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false)
+ val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events")
+ val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events")
+ try {
+ val serverRefInServer2 =
+ serverEnv1.setupEndpointRef("server2", serverRef2.address, serverRef2.name)
+ // Send a message to set up the connection
+ serverRefInServer2.send("hello")
+
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events.contains(("onConnected", serverEnv2.address)))
}
- }
- anotherEnv.shutdown()
- anotherEnv.awaitTermination()
- eventually(timeout(5 seconds), interval(5 millis)) {
- // Account for anotherEnv not having an address due to running in client mode.
- if (remoteAddress != null) {
- assert(events === List(
- ("onConnected", remoteAddress),
- ("onNetworkError", remoteAddress),
- ("onDisconnected", remoteAddress)) ||
- events === List(
- ("onConnected", remoteAddress),
- ("onDisconnected", remoteAddress)))
- } else {
- val eventNames = events.map(_._1)
- assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") ||
- eventNames === List("onConnected", "onDisconnected"))
+ serverEnv2.shutdown()
+ serverEnv2.awaitTermination()
+
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events.contains(("onConnected", serverEnv2.address)))
+ assert(events.contains(("onDisconnected", serverEnv2.address)))
}
+ } finally {
+ serverEnv1.shutdown()
+ serverEnv2.shutdown()
+ serverEnv1.awaitTermination()
+ serverEnv2.awaitTermination()
}
}
- test("network events between non-client-mode RpcEnvs") {
- val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
- env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint {
- override val rpcEnv = env
+ test("network events in sever RpcEnv when another RpcEnv is in client mode") {
+ val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false)
+ val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events")
+ val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true)
+ try {
+ val serverRefInClient =
+ clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name)
+ // Send a message to set up the connection
+ serverRefInClient.send("hello")
- override def receive: PartialFunction[Any, Unit] = {
- case "hello" =>
- case m => events += "receive" -> m
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ // We don't know the exact client address but at least we can verify the message type
+ assert(events.map(_._1).contains("onConnected"))
}
- override def onConnected(remoteAddress: RpcAddress): Unit = {
- events += "onConnected" -> remoteAddress
- }
+ clientEnv.shutdown()
+ clientEnv.awaitTermination()
- override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- events += "onDisconnected" -> remoteAddress
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ // We don't know the exact client address but at least we can verify the message type
+ assert(events.map(_._1).contains("onConnected"))
+ assert(events.map(_._1).contains("onDisconnected"))
}
+ } finally {
+ clientEnv.shutdown()
+ serverEnv.shutdown()
+ clientEnv.awaitTermination()
+ serverEnv.awaitTermination()
+ }
+ }
- override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
- events += "onNetworkError" -> remoteAddress
- }
+ test("network events in client RpcEnv when another RpcEnv is in server mode") {
+ val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true)
+ val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false)
+ val (_, events) = setupNetworkEndpoint(clientEnv, "network-events")
+ val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events")
+ try {
+ val serverRefInClient =
+ clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name)
+ // Send a message to set up the connection
+ serverRefInClient.send("hello")
- })
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events.contains(("onConnected", serverEnv.address)))
+ }
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false)
- // Use anotherEnv to find out the RpcEndpointRef
- val rpcEndpointRef = anotherEnv.setupEndpointRef(
- "local", env.address, "network-events-non-client")
- val remoteAddress = anotherEnv.address
- rpcEndpointRef.send("hello")
- eventually(timeout(5 seconds), interval(5 millis)) {
- assert(events.contains(("onConnected", remoteAddress)))
- }
+ serverEnv.shutdown()
+ serverEnv.awaitTermination()
- anotherEnv.shutdown()
- anotherEnv.awaitTermination()
- eventually(timeout(5 seconds), interval(5 millis)) {
- assert(events.contains(("onConnected", remoteAddress)))
- assert(events.contains(("onDisconnected", remoteAddress)))
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events.contains(("onConnected", serverEnv.address)))
+ assert(events.contains(("onDisconnected", serverEnv.address)))
+ }
+ } finally {
+ clientEnv.shutdown()
+ serverEnv.shutdown()
+ clientEnv.awaitTermination()
+ serverEnv.awaitTermination()
}
}
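The tests above lean on ScalaTest's Eventually: because network events arrive asynchronously, assertions are retried at a fixed interval until they pass or a timeout expires. A standalone sketch of that pattern (EventuallyPatternSketch is an illustrative name):

    import scala.collection.mutable
    import org.scalatest.concurrent.Eventually._
    import org.scalatest.time.SpanSugar._

    object EventuallyPatternSketch {
      def main(args: Array[String]): Unit = {
        val events = mutable.ArrayBuffer.empty[(String, String)]

        // Simulate an event delivered asynchronously by the network layer.
        new Thread(() => {
          Thread.sleep(50)
          events.synchronized { events += (("onConnected", "host:1234")) }
        }).start()

        // Retry the assertion every 5 ms until it passes or 5 seconds elapse.
        eventually(timeout(5.seconds), interval(5.millis)) {
          assert(events.synchronized(events.exists(_._1 == "onConnected")))
        }
      }
    }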
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index ebd6f70071..d4aebe9fd9 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -43,7 +43,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.receive(client, null, null)
+ nettyRpcHandler.channelActive(client)
verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
}
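The suite follows the usual Mockito stub-then-verify flow. A self-contained sketch with hypothetical Channel and Dispatcher stand-ins (the real test mocks Netty's Channel and verifies against Spark's Dispatcher):

    import java.net.InetSocketAddress
    import org.mockito.Mockito.{mock, times, verify, when}

    object MockVerifySketch {
      trait Channel { def remoteAddress(): InetSocketAddress }
      trait Dispatcher { def postToAll(event: Any): Unit }

      def main(args: Array[String]): Unit = {
        // Stub: the mocked channel reports a fixed remote address.
        val channel = mock(classOf[Channel])
        when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))

        // Exercise: the code under test posts one event per channel activation.
        val dispatcher = mock(classOf[Dispatcher])
        dispatcher.postToAll(("onConnected", channel.remoteAddress()))

        // Verify: exactly one event was dispatched for that address.
        verify(dispatcher, times(1)).postToAll(("onConnected", new InetSocketAddress("localhost", 40000)))
      }
    }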
@@ -55,10 +55,10 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.receive(client, null, null)
+ nettyRpcHandler.channelActive(client)
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.connectionTerminated(client)
+ nettyRpcHandler.channelInactive(client)
verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
verify(dispatcher, times(1)).postToAll(