author    Reynold Xin <rxin@databricks.com>  2015-10-13 09:51:20 -0700
committer Reynold Xin <rxin@databricks.com>  2015-10-13 09:51:20 -0700
commit    1797055dbf1d2fd7714d7c65c8d2efde2f15efc1 (patch)
tree      a2fd05a8ba259c25dd01bf0b4af48c2466a39b83
parent    6987c067937a50867b4d5788f5bf496ecdfdb62c (diff)
[SPARK-11079] Post-hoc review Netty-based RPC - round 1
I'm going through the implementation right now for post-hoc review. Adding more comments and renaming things as I go through them. I also want to write higher level documentation about how the whole thing works -- but those will come in other pull requests.

Author: Reynold Xin <rxin@databricks.com>

Closes #9091 from rxin/rpc-review.
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala              |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala                 |  50
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala                |   3
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala                     | 153
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala                 | 131
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala            |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala           | 108
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala           |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala                | 119
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala  |  11
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala          |  38
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ThreadUtils.scala               |   1
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala                     |   1
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala           |   6
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala |   7
15 files changed, 336 insertions(+), 302 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 45e12e40c8..72355cdfa6 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -48,7 +48,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
- val serializedSize = mapOutputStatuses.size
+ val serializedSize = mapOutputStatuses.length
if (serializedSize > maxAkkaFrameSize) {
val msg = s"Map output statuses were $serializedSize bytes which " +
s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala
new file mode 100644
index 0000000000..eb0b26947f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import org.apache.spark.util.Utils
+
+
+/**
+ * Address for an RPC environment, with hostname and port.
+ */
+private[spark] case class RpcAddress(host: String, port: Int) {
+
+ def hostPort: String = host + ":" + port
+
+ /** Returns a string in the form of "spark://host:port". */
+ def toSparkURL: String = "spark://" + hostPort
+
+ override def toString: String = hostPort
+}
+
+
+private[spark] object RpcAddress {
+
+ /** Return the [[RpcAddress]] represented by `uri`. */
+ def fromURIString(uri: String): RpcAddress = {
+ val uriObj = new java.net.URI(uri)
+ RpcAddress(uriObj.getHost, uriObj.getPort)
+ }
+
+ /** Returns the [[RpcAddress]] encoded in the form of "spark://host:port" */
+ def fromSparkURL(sparkUrl: String): RpcAddress = {
+ val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
+ RpcAddress(host, port)
+ }
+}
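A quick usage sketch for the helpers above, e.g. in a REPL with this class on the classpath (host and port values are invented; fromSparkURL delegates to Utils.extractHostPortFromSparkUrl as shown):

val addr = RpcAddress("spark-master", 7077)
addr.hostPort                                          // "spark-master:7077"
addr.toSparkURL                                        // "spark://spark-master:7077"
RpcAddress.fromURIString("spark://spark-master:7077")  // RpcAddress("spark-master", 7077)
RpcAddress.fromSparkURL("spark://spark-master:7077")   // RpcAddress("spark-master", 7077)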
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index f1ddc6d2cd..0ba9516952 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -145,5 +145,4 @@ private[spark] trait RpcEndpoint {
* However, there is no guarantee that the same thread will be executing the same
* [[ThreadSafeRpcEndpoint]] for different messages.
*/
-private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint {
-}
+private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 35e402c725..ef491a0ae4 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -17,12 +17,7 @@
package org.apache.spark.rpc
-import java.net.URI
-import java.util.concurrent.TimeoutException
-
-import scala.concurrent.{Awaitable, Await, Future}
-import scala.concurrent.duration._
-import scala.language.postfixOps
+import scala.concurrent.Future
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.{RpcUtils, Utils}
@@ -35,8 +30,8 @@ import org.apache.spark.util.{RpcUtils, Utils}
private[spark] object RpcEnv {
private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
- // Add more RpcEnv implementations here
- val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
+ val rpcEnvNames = Map(
+ "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
"netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")
val rpcEnvName = conf.get("spark.rpc", "netty")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
@@ -53,7 +48,6 @@ private[spark] object RpcEnv {
val config = RpcEnvConfig(conf, name, host, port, securityManager)
getRpcEnvFactory(conf).create(config)
}
-
}
@@ -155,144 +149,3 @@ private[spark] case class RpcEnvConfig(
host: String,
port: Int,
securityManager: SecurityManager)
-
-
-/**
- * Represents a host and port.
- */
-private[spark] case class RpcAddress(host: String, port: Int) {
- // TODO do we need to add the type of RpcEnv in the address?
-
- val hostPort: String = host + ":" + port
-
- override val toString: String = hostPort
-
- def toSparkURL: String = "spark://" + hostPort
-}
-
-
-private[spark] object RpcAddress {
-
- /**
- * Return the [[RpcAddress]] represented by `uri`.
- */
- def fromURI(uri: URI): RpcAddress = {
- RpcAddress(uri.getHost, uri.getPort)
- }
-
- /**
- * Return the [[RpcAddress]] represented by `uri`.
- */
- def fromURIString(uri: String): RpcAddress = {
- fromURI(new java.net.URI(uri))
- }
-
- def fromSparkURL(sparkUrl: String): RpcAddress = {
- val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
- RpcAddress(host, port)
- }
-}
-
-
-/**
- * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
- */
-private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
- extends TimeoutException(message) { initCause(cause) }
-
-
-/**
- * Associates a timeout with a description so that a when a TimeoutException occurs, additional
- * context about the timeout can be amended to the exception message.
- * @param duration timeout duration in seconds
- * @param timeoutProp the configuration property that controls this timeout
- */
-private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String)
- extends Serializable {
-
- /** Amends the standard message of TimeoutException to include the description */
- private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
- new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te)
- }
-
- /**
- * PartialFunction to match a TimeoutException and add the timeout description to the message
- *
- * @note This can be used in the recover callback of a Future to add to a TimeoutException
- * Example:
- * val timeout = new RpcTimeout(5 millis, "short timeout")
- * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
- */
- def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
- // The exception has already been converted to a RpcTimeoutException so just raise it
- case rte: RpcTimeoutException => throw rte
- // Any other TimeoutException get converted to a RpcTimeoutException with modified message
- case te: TimeoutException => throw createRpcTimeoutException(te)
- }
-
- /**
- * Wait for the completed result and return it. If the result is not available within this
- * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout.
- * @param awaitable the `Awaitable` to be awaited
- * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
- * is still not ready
- */
- def awaitResult[T](awaitable: Awaitable[T]): T = {
- try {
- Await.result(awaitable, duration)
- } catch addMessageIfTimeout
- }
-}
-
-
-private[spark] object RpcTimeout {
-
- /**
- * Lookup the timeout property in the configuration and create
- * a RpcTimeout with the property key in the description.
- * @param conf configuration properties containing the timeout
- * @param timeoutProp property key for the timeout in seconds
- * @throws NoSuchElementException if property is not set
- */
- def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
- val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds }
- new RpcTimeout(timeout, timeoutProp)
- }
-
- /**
- * Lookup the timeout property in the configuration and create
- * a RpcTimeout with the property key in the description.
- * Uses the given default value if property is not set
- * @param conf configuration properties containing the timeout
- * @param timeoutProp property key for the timeout in seconds
- * @param defaultValue default timeout value in seconds if property not found
- */
- def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
- val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds }
- new RpcTimeout(timeout, timeoutProp)
- }
-
- /**
- * Lookup prioritized list of timeout properties in the configuration
- * and create a RpcTimeout with the first set property key in the
- * description.
- * Uses the given default value if property is not set
- * @param conf configuration properties containing the timeout
- * @param timeoutPropList prioritized list of property keys for the timeout in seconds
- * @param defaultValue default timeout value in seconds if no properties found
- */
- def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = {
- require(timeoutPropList.nonEmpty)
-
- // Find the first set property or use the default value with the first property
- val itr = timeoutPropList.iterator
- var foundProp: Option[(String, String)] = None
- while (itr.hasNext && foundProp.isEmpty){
- val propKey = itr.next()
- conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
- }
- val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
- val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds }
- new RpcTimeout(timeout, finalProp._1)
- }
-}
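The reshaped rpcEnvNames map above preserves a useful property: spark.rpc accepts either a short alias or a fully qualified RpcEnvFactory class name, defaulting to "netty". A self-contained sketch of that lookup pattern (the map contents mirror the code; the object and resolve are made-up names):

object FactoryLookupSketch {
  private val rpcEnvNames = Map(
    "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
    "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")

  // Unknown keys fall through unchanged, so they are treated as class names.
  def resolve(name: String): String = rpcEnvNames.getOrElse(name.toLowerCase, name)

  def main(args: Array[String]): Unit = {
    println(resolve("NETTY"))                        // org.apache.spark.rpc.netty.NettyRpcEnvFactory
    println(resolve("com.example.MyRpcEnvFactory"))  // passed through unchanged
  }
}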
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
new file mode 100644
index 0000000000..285786ebf9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import java.util.concurrent.TimeoutException
+
+import scala.concurrent.{Awaitable, Await}
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.util.Utils
+
+
+/**
+ * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
+ */
+private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
+ extends TimeoutException(message) { initCause(cause) }
+
+
+/**
+ * Associates a timeout with a description so that when a TimeoutException occurs, additional
+ * context about the timeout can be added to the exception message.
+ *
+ * @param duration timeout duration in seconds
+ * @param timeoutProp the configuration property that controls this timeout
+ */
+private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String)
+ extends Serializable {
+
+ /** Amends the standard message of TimeoutException to include the description */
+ private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
+ new RpcTimeoutException(te.getMessage + ". This timeout is controlled by " + timeoutProp, te)
+ }
+
+ /**
+ * PartialFunction to match a TimeoutException and add the timeout description to the message
+ *
+ * @note This can be used in the recover callback of a Future to add to a TimeoutException
+ * Example:
+ * val timeout = new RpcTimeout(5 millis, "short timeout")
+ * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
+ */
+ def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
+ // The exception has already been converted to a RpcTimeoutException so just raise it
+ case rte: RpcTimeoutException => throw rte
+    // Any other TimeoutException gets converted to an RpcTimeoutException with a modified message
+ case te: TimeoutException => throw createRpcTimeoutException(te)
+ }
+
+ /**
+ * Wait for the completed result and return it. If the result is not available within this
+ * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout.
+ * @param awaitable the `Awaitable` to be awaited
+ * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
+ * is still not ready
+ */
+ def awaitResult[T](awaitable: Awaitable[T]): T = {
+ try {
+ Await.result(awaitable, duration)
+ } catch addMessageIfTimeout
+ }
+}
+
+
+private[spark] object RpcTimeout {
+
+ /**
+ * Lookup the timeout property in the configuration and create
+ * a RpcTimeout with the property key in the description.
+ * @param conf configuration properties containing the timeout
+ * @param timeoutProp property key for the timeout in seconds
+ * @throws NoSuchElementException if property is not set
+ */
+ def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
+ val timeout = { conf.getTimeAsSeconds(timeoutProp).seconds }
+ new RpcTimeout(timeout, timeoutProp)
+ }
+
+ /**
+ * Lookup the timeout property in the configuration and create
+ * a RpcTimeout with the property key in the description.
+ * Uses the given default value if property is not set
+ * @param conf configuration properties containing the timeout
+ * @param timeoutProp property key for the timeout in seconds
+ * @param defaultValue default timeout value in seconds if property not found
+ */
+ def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
+ val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue).seconds }
+ new RpcTimeout(timeout, timeoutProp)
+ }
+
+ /**
+ * Lookup prioritized list of timeout properties in the configuration
+ * and create a RpcTimeout with the first set property key in the
+ * description.
+   * Uses the given default value if no property is set
+ * @param conf configuration properties containing the timeout
+ * @param timeoutPropList prioritized list of property keys for the timeout in seconds
+ * @param defaultValue default timeout value in seconds if no properties found
+ */
+ def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = {
+ require(timeoutPropList.nonEmpty)
+
+ // Find the first set property or use the default value with the first property
+ val itr = timeoutPropList.iterator
+ var foundProp: Option[(String, String)] = None
+    while (itr.hasNext && foundProp.isEmpty) {
+ val propKey = itr.next()
+ conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
+ }
+ val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
+ val timeout = { Utils.timeStringAsSeconds(finalProp._2).seconds }
+ new RpcTimeout(timeout, finalProp._1)
+ }
+}
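A hypothetical use of the class above, assuming it runs somewhere RpcTimeout is visible (the config key and duration are invented; a future that never completes triggers the amended exception):

import java.util.concurrent.TimeoutException
import scala.concurrent.Promise
import scala.concurrent.duration._

object RpcTimeoutSketch {
  def main(args: Array[String]): Unit = {
    val timeout = new RpcTimeout(50.millis, "spark.example.askTimeout")
    val never = Promise[String]().future // never completed
    try timeout.awaitResult(never) catch {
      case e: TimeoutException =>
        // Message ends with ". This timeout is controlled by spark.example.askTimeout"
        println(e.getMessage)
    }
  }
}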
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 95132a4e4a..3fad595a0d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -39,10 +39,6 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
*
* TODO Once we remove all usages of Akka in other place, we can move this file to a new project and
* remove Akka from the dependencies.
- *
- * @param actorSystem
- * @param conf
- * @param boundPort
*/
private[spark] class AkkaRpcEnv private[akka] (
val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index d71e6f01db..398e9eafc1 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -17,7 +17,7 @@
package org.apache.spark.rpc.netty
-import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
@@ -38,12 +38,16 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
val inbox = new Inbox(ref, endpoint)
}
- private val endpoints = new ConcurrentHashMap[String, EndpointData]()
- private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]()
+ private val endpoints = new ConcurrentHashMap[String, EndpointData]
+ private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
// Track the receivers whose inboxes may contain messages.
private val receivers = new LinkedBlockingQueue[EndpointData]()
+ /**
+ * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
+ * immediately.
+ */
@GuardedBy("this")
private var stopped = false
@@ -59,7 +63,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
}
val data = endpoints.get(name)
endpointRefs.put(data.endpoint, data.ref)
- receivers.put(data)
+ receivers.put(data) // for the OnStart message
}
endpointRef
}
@@ -73,7 +77,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
val data = endpoints.remove(name)
if (data != null) {
data.inbox.stop()
- receivers.put(data)
+ receivers.put(data) // for the OnStop message
}
// Don't clean `endpointRefs` here because it's possible that some messages are being processed
// now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
@@ -91,19 +95,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
}
/**
- * Send a message to all registered [[RpcEndpoint]]s.
- * @param message
+ * Send a message to all registered [[RpcEndpoint]]s in this process.
+ *
+   * This can be used to make network events known to all endpoints (e.g. "a new node connected").
*/
- def broadcastMessage(message: InboxMessage): Unit = {
+ def postToAll(message: InboxMessage): Unit = {
val iter = endpoints.keySet().iterator()
while (iter.hasNext) {
val name = iter.next
- postMessageToInbox(name, (_) => message,
- () => { logWarning(s"Drop ${message} because ${name} has been stopped") })
+ postMessage(
+ name,
+ _ => message,
+ () => { logWarning(s"Drop $message because $name has been stopped") })
}
}
- def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
+ /** Posts a message sent by a remote endpoint. */
+ def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
val rpcCallContext =
new RemoteNettyRpcCallContext(
@@ -116,10 +124,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
}
- postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
+ postMessage(message.receiver.name, createMessage, onEndpointStopped)
}
- def postMessage(message: RequestMessage, p: Promise[Any]): Unit = {
+ /** Posts a message sent by a local endpoint. */
+ def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
val rpcCallContext =
new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p)
@@ -131,39 +140,36 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
}
- postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
+ postMessage(message.receiver.name, createMessage, onEndpointStopped)
}
- private def postMessageToInbox(
+ /**
+ * Posts a message to a specific endpoint.
+ *
+ * @param endpointName name of the endpoint.
+ * @param createMessageFn function to create the message.
+ * @param callbackIfStopped callback function if the endpoint is stopped.
+ */
+ private def postMessage(
endpointName: String,
createMessageFn: NettyRpcEndpointRef => InboxMessage,
- onStopped: () => Unit): Unit = {
- val shouldCallOnStop =
- synchronized {
- val data = endpoints.get(endpointName)
- if (stopped || data == null) {
- true
- } else {
- data.inbox.post(createMessageFn(data.ref))
- receivers.put(data)
- false
- }
+ callbackIfStopped: () => Unit): Unit = {
+ val shouldCallOnStop = synchronized {
+ val data = endpoints.get(endpointName)
+ if (stopped || data == null) {
+ true
+ } else {
+ data.inbox.post(createMessageFn(data.ref))
+ receivers.put(data)
+ false
}
+ }
if (shouldCallOnStop) {
// We don't need to call `onStop` in the `synchronized` block
- onStopped()
+ callbackIfStopped()
}
}
- private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism",
- Runtime.getRuntime.availableProcessors())
-
- private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop")
-
- (0 until parallelism) foreach { _ =>
- executor.execute(new MessageLoop)
- }
-
def stop(): Unit = {
synchronized {
if (stopped) {
@@ -174,12 +180,12 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
// Stop all endpoints. This will queue all endpoints for processing by the message loops.
endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
// Enqueue a message that tells the message loops to stop.
- receivers.put(PoisonEndpoint)
- executor.shutdown()
+ receivers.put(PoisonPill)
+ threadpool.shutdown()
}
def awaitTermination(): Unit = {
- executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+ threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
}
/**
@@ -189,15 +195,27 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
endpoints.containsKey(name)
}
+ /** Thread pool used for dispatching messages. */
+ private val threadpool: ThreadPoolExecutor = {
+ val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
+ Runtime.getRuntime.availableProcessors())
+ val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
+ for (i <- 0 until numThreads) {
+ pool.execute(new MessageLoop)
+ }
+ pool
+ }
+
+ /** Message loop used for dispatching messages. */
private class MessageLoop extends Runnable {
override def run(): Unit = {
try {
while (true) {
try {
val data = receivers.take()
- if (data == PoisonEndpoint) {
- // Put PoisonEndpoint back so that other MessageLoops can see it.
- receivers.put(PoisonEndpoint)
+ if (data == PoisonPill) {
+ // Put PoisonPill back so that other MessageLoops can see it.
+ receivers.put(PoisonPill)
return
}
data.inbox.process(Dispatcher.this)
@@ -211,8 +229,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
}
}
- /**
- * A poison endpoint that indicates MessageLoop should exit its loop.
- */
- private val PoisonEndpoint = new EndpointData(null, null, null)
+ /** A poison endpoint that indicates MessageLoop should exit its message loop. */
+ private val PoisonPill = new EndpointData(null, null, null)
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
index 6061c9b8de..fa9a3eb99b 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
@@ -26,8 +26,8 @@ private[netty] case class ID(name: String)
/**
* An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
*/
-private[netty] class IDVerifier(
- override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint {
+private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
+ extends RpcEndpoint {
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case ID(name) => context.reply(dispatcher.verify(name))
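A minimal sketch of the receiveAndReply shape IDVerifier implements (CallContext here is a simplified stand-in for Spark's RpcCallContext, and the class name is made up):

trait CallContext { def reply(response: Any): Unit }

class VerifierSketch(exists: String => Boolean) {
  def receiveAndReply(context: CallContext): PartialFunction[Any, Unit] = {
    case name: String => context.reply(exists(name)) // true if the endpoint is registered
  }
}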
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index b669f59a28..c72b588db5 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -17,14 +17,16 @@
package org.apache.spark.rpc.netty
-import java.util.LinkedList
import javax.annotation.concurrent.GuardedBy
import scala.util.control.NonFatal
+import com.google.common.annotations.VisibleForTesting
+
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
+
private[netty] sealed trait InboxMessage
private[netty] case class ContentMessage(
@@ -37,44 +39,40 @@ private[netty] case object OnStart extends InboxMessage
private[netty] case object OnStop extends InboxMessage
-/**
- * A broadcast message that indicates connecting to a remote node.
- */
-private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage
+/** A message to tell all endpoints that a remote process has connected. */
+private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress) extends InboxMessage
-/**
- * A broadcast message that indicates a remote connection is lost.
- */
-private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage
+/** A message to tell all endpoints that a remote process has disconnected. */
+private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress) extends InboxMessage
-/**
- * A broadcast message that indicates a network error
- */
-private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress)
+/** A message to tell all endpoints that a network error has happened. */
+private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress)
extends InboxMessage
/**
 * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
- * @param endpointRef
- * @param endpoint
*/
private[netty] class Inbox(
val endpointRef: NettyRpcEndpointRef,
- val endpoint: RpcEndpoint) extends Logging {
+ val endpoint: RpcEndpoint)
+ extends Logging {
- inbox =>
+ inbox => // Give this an alias so we can use it more clearly in closures.
@GuardedBy("this")
- protected val messages = new LinkedList[InboxMessage]()
+ protected val messages = new java.util.LinkedList[InboxMessage]()
+ /** True if the inbox (and its associated endpoint) is stopped. */
@GuardedBy("this")
private var stopped = false
+ /** Allow multiple threads to process messages at the same time. */
@GuardedBy("this")
private var enableConcurrent = false
+ /** The number of threads processing messages for this inbox. */
@GuardedBy("this")
- private var workerCount = 0
+ private var numActiveThreads = 0
// OnStart should be the first message to process
inbox.synchronized {
@@ -87,12 +85,12 @@ private[netty] class Inbox(
def process(dispatcher: Dispatcher): Unit = {
var message: InboxMessage = null
inbox.synchronized {
- if (!enableConcurrent && workerCount != 0) {
+ if (!enableConcurrent && numActiveThreads != 0) {
return
}
message = messages.poll()
if (message != null) {
- workerCount += 1
+ numActiveThreads += 1
} else {
return
}
@@ -101,15 +99,11 @@ private[netty] class Inbox(
safelyCall(endpoint) {
message match {
case ContentMessage(_sender, content, needReply, context) =>
- val pf: PartialFunction[Any, Unit] =
- if (needReply) {
- endpoint.receiveAndReply(context)
- } else {
- endpoint.receive
- }
+ // The partial function to call
+ val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive
try {
pf.applyOrElse[Any, Unit](content, { msg =>
- throw new SparkException(s"Unmatched message $message from ${_sender}")
+ throw new SparkException(s"Unsupported message $message from ${_sender}")
})
if (!needReply) {
context.finish()
@@ -121,11 +115,13 @@ private[netty] class Inbox(
context.sendFailure(e)
} else {
context.finish()
- throw e
}
+ // Throw the exception -- this exception will be caught by the safelyCall function.
+ // The endpoint's onError function will be called.
+ throw e
}
- case OnStart => {
+ case OnStart =>
endpoint.onStart()
if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
inbox.synchronized {
@@ -134,24 +130,22 @@ private[netty] class Inbox(
}
}
}
- }
case OnStop =>
- val _workCount = inbox.synchronized {
- workerCount
- }
- assert(_workCount == 1, s"There should be only one worker but was ${_workCount}")
+ val activeThreads = inbox.synchronized { inbox.numActiveThreads }
+ assert(activeThreads == 1,
+ s"There should be only a single active thread but found $activeThreads threads.")
dispatcher.removeRpcEndpointRef(endpoint)
endpoint.onStop()
assert(isEmpty, "OnStop should be the last message")
- case Associated(remoteAddress) =>
+ case RemoteProcessConnected(remoteAddress) =>
endpoint.onConnected(remoteAddress)
- case Disassociated(remoteAddress) =>
+ case RemoteProcessDisconnected(remoteAddress) =>
endpoint.onDisconnected(remoteAddress)
- case AssociationError(cause, remoteAddress) =>
+ case RemoteProcessConnectionError(cause, remoteAddress) =>
endpoint.onNetworkError(cause, remoteAddress)
}
}
@@ -159,33 +153,27 @@ private[netty] class Inbox(
inbox.synchronized {
// "enableConcurrent" will be set to false after `onStop` is called, so we should check it
// every time.
- if (!enableConcurrent && workerCount != 1) {
+ if (!enableConcurrent && numActiveThreads != 1) {
// If we are not the only worker, exit
- workerCount -= 1
+ numActiveThreads -= 1
return
}
message = messages.poll()
if (message == null) {
- workerCount -= 1
+ numActiveThreads -= 1
return
}
}
}
}
- def post(message: InboxMessage): Unit = {
- val dropped =
- inbox.synchronized {
- if (stopped) {
- // We already put "OnStop" into "messages", so we should drop further messages
- true
- } else {
- messages.add(message)
- false
- }
- }
- if (dropped) {
+ def post(message: InboxMessage): Unit = inbox.synchronized {
+ if (stopped) {
+ // We already put "OnStop" into "messages", so we should drop further messages
onDrop(message)
+ } else {
+ messages.add(message)
}
}
@@ -203,24 +191,23 @@ private[netty] class Inbox(
}
}
- // Visible for testing.
+ def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }
+
+ /** Called when we are dropping a message. Test cases override this to test message dropping. */
+ @VisibleForTesting
protected def onDrop(message: InboxMessage): Unit = {
- logWarning(s"Drop ${message} because $endpointRef is stopped")
+ logWarning(s"Drop $message because $endpointRef is stopped")
}
- def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }
-
+ /**
+ * Calls action closure, and calls the endpoint's onError function in the case of exceptions.
+ */
private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
- try {
- action
- } catch {
- case NonFatal(e) => {
- try {
- endpoint.onError(e)
- } catch {
- case NonFatal(e) => logWarning(s"Ignore error", e)
+ try action catch {
+ case NonFatal(e) =>
+ try endpoint.onError(e) catch {
+          case NonFatal(ee) => logError("Ignoring error", ee)
}
- }
}
}
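The renamed numActiveThreads field drives the inbox's concurrency gate: at most one thread may process messages until OnStart flips enableConcurrent for non-ThreadSafeRpcEndpoints. A condensed sketch of that gate (heavily simplified; the real Inbox dispatches typed InboxMessages):

class InboxGateSketch {
  private val messages = new java.util.LinkedList[String]()
  private var enableConcurrent = false
  private var numActiveThreads = 0

  def post(message: String): Unit = synchronized { messages.add(message) }

  def process(): Unit = {
    var message: String = null
    synchronized {
      if (!enableConcurrent && numActiveThreads != 0) return // another thread owns the inbox
      message = messages.poll()
      if (message == null) return
      numActiveThreads += 1
    }
    while (true) {
      println(s"handling $message")
      synchronized {
        // Re-check every time: enableConcurrent is reset to false after OnStop.
        if (!enableConcurrent && numActiveThreads != 1) { numActiveThreads -= 1; return }
        message = messages.poll()
        if (message == null) { numActiveThreads -= 1; return }
      }
    }
  }
}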
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
index 75dcc02a0c..21d5bb4923 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
@@ -26,7 +26,8 @@ import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
private[netty] abstract class NettyRpcCallContext(
endpointRef: NettyRpcEndpointRef,
override val senderAddress: RpcAddress,
- needReply: Boolean) extends RpcCallContext with Logging {
+ needReply: Boolean)
+ extends RpcCallContext with Logging {
protected def send(message: Any): Unit
@@ -35,7 +36,7 @@ private[netty] abstract class NettyRpcCallContext(
send(AskResponse(endpointRef, response))
} else {
throw new IllegalStateException(
- s"Cannot send $response to the sender because the sender won't handle it")
+ s"Cannot send $response to the sender because the sender does not expect a reply")
}
}
@@ -63,7 +64,8 @@ private[netty] class LocalNettyRpcCallContext(
endpointRef: NettyRpcEndpointRef,
senderAddress: RpcAddress,
needReply: Boolean,
- p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+ p: Promise[Any])
+ extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
override protected def send(message: Any): Unit = {
p.success(message)
@@ -78,7 +80,8 @@ private[netty] class RemoteNettyRpcCallContext(
endpointRef: NettyRpcEndpointRef,
callback: RpcResponseCallback,
senderAddress: RpcAddress,
- needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+ needReply: Boolean)
+ extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
override protected def send(message: Any): Unit = {
val reply = nettyEnv.serialize(message)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 5522b40782..89b6df76c2 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -19,7 +19,6 @@ package org.apache.spark.rpc.netty
import java.io._
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
-import java.util.Arrays
import java.util.concurrent._
import javax.annotation.concurrent.GuardedBy
@@ -77,19 +76,19 @@ private[netty] class NettyRpcEnv(
@volatile private var server: TransportServer = _
def start(port: Int): Unit = {
- val bootstraps: Seq[TransportServerBootstrap] =
+ val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
- Seq(new SaslServerBootstrap(transportConf, securityManager))
+ java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
} else {
- Nil
+ java.util.Collections.emptyList()
}
- server = transportContext.createServer(port, bootstraps.asJava)
+ server = transportContext.createServer(port, bootstraps)
dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
}
override lazy val address: RpcAddress = {
require(server != null, "NettyRpcEnv has not yet started")
- RpcAddress(host, server.getPort())
+ RpcAddress(host, server.getPort)
}
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
@@ -119,7 +118,7 @@ private[netty] class NettyRpcEnv(
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
val promise = Promise[Any]()
- dispatcher.postMessage(message, promise)
+ dispatcher.postLocalMessage(message, promise)
promise.future.onComplete {
case Success(response) =>
val ack = response.asInstanceOf[Ack]
@@ -148,10 +147,9 @@ private[netty] class NettyRpcEnv(
}
})
} catch {
- case e: RejectedExecutionException => {
+ case e: RejectedExecutionException =>
// `send` after shutting clientConnectionExecutor down, ignore it
- logWarning(s"Cannot send ${message} because RpcEnv is stopped")
- }
+ logWarning(s"Cannot send $message because RpcEnv is stopped")
}
}
}
@@ -161,7 +159,7 @@ private[netty] class NettyRpcEnv(
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
val p = Promise[Any]()
- dispatcher.postMessage(message, p)
+ dispatcher.postLocalMessage(message, p)
p.future.onComplete {
case Success(response) =>
val reply = response.asInstanceOf[AskResponse]
@@ -218,7 +216,7 @@ private[netty] class NettyRpcEnv(
private[netty] def serialize(content: Any): Array[Byte] = {
val buffer = javaSerializerInstance.serialize(content)
- Arrays.copyOfRange(
+ java.util.Arrays.copyOfRange(
buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
}
@@ -425,7 +423,7 @@ private[netty] class NettyRpcHandler(
assert(addr != null)
val remoteEnvAddress = requestMessage.senderAddress
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- val broadcastMessage =
+ val broadcastMessage: Option[RemoteProcessConnected] =
synchronized {
// If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
@@ -435,7 +433,7 @@ private[netty] class NettyRpcHandler(
remoteConnectionCount.put(remoteEnvAddress, count + 1)
if (count == 0) {
// This is the first connection, so fire "Associated"
- Some(Associated(remoteEnvAddress))
+ Some(RemoteProcessConnected(remoteEnvAddress))
} else {
None
}
@@ -443,8 +441,8 @@ private[netty] class NettyRpcHandler(
None
}
}
- broadcastMessage.foreach(dispatcher.broadcastMessage)
- dispatcher.postMessage(requestMessage, callback)
+ broadcastMessage.foreach(dispatcher.postToAll)
+ dispatcher.postRemoteMessage(requestMessage, callback)
}
override def getStreamManager: StreamManager = new OneForOneStreamManager
@@ -455,12 +453,12 @@ private[netty] class NettyRpcHandler(
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
val broadcastMessage =
synchronized {
- remoteAddresses.get(clientAddr).map(AssociationError(cause, _))
+ remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _))
}
if (broadcastMessage.isEmpty) {
logError(cause.getMessage, cause)
} else {
- dispatcher.broadcastMessage(broadcastMessage.get)
+ dispatcher.postToAll(broadcastMessage.get)
}
} else {
// If the channel is closed before connecting, its remoteAddress will be null.
@@ -485,7 +483,7 @@ private[netty] class NettyRpcHandler(
if (count - 1 == 0) {
// We lost all clients, so clean up and fire "Disassociated"
remoteConnectionCount.remove(remoteEnvAddress)
- Some(Disassociated(remoteEnvAddress))
+ Some(RemoteProcessDisconnected(remoteEnvAddress))
} else {
// Decrease the connection number of remoteEnvAddress
remoteConnectionCount.put(remoteEnvAddress, count - 1)
@@ -493,7 +491,7 @@ private[netty] class NettyRpcHandler(
}
}
}
- broadcastMessage.foreach(dispatcher.broadcastMessage)
+ broadcastMessage.foreach(dispatcher.postToAll)
} else {
// If the channel is closed before connecting, its remoteAddress will be null. In this case,
// we can ignore it since we don't fire "Associated".
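The renames above make the handler's connection bookkeeping read naturally: broadcast RemoteProcessConnected for the first channel from a remote RpcEnv and RemoteProcessDisconnected once the last one closes. A self-contained sketch of that reference counting (simplified to string addresses; the class name is invented):

class ConnectionCounterSketch {
  private val counts = scala.collection.mutable.Map.empty[String, Int]

  /** Returns true when this is the first connection from addr. */
  def connected(addr: String): Boolean = synchronized {
    val c = counts.getOrElse(addr, 0)
    counts(addr) = c + 1
    c == 0
  }

  /** Returns true when the last connection from addr has gone away. */
  def disconnected(addr: String): Boolean = synchronized {
    counts.getOrElse(addr, 0) match {
      case 0 => false                     // never connected; ignore
      case 1 => counts.remove(addr); true // we lost all clients
      case c => counts(addr) = c - 1; false
    }
  }
}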
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 1ed098379e..15e7519d70 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -15,7 +15,6 @@
* limitations under the License.
*/
-
package org.apache.spark.util
import java.util.concurrent._
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e60c1b355a..bd7e51c3b5 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1895,6 +1895,7 @@ private[spark] object Utils extends Logging {
* This is expected to throw java.net.BindException on port collision.
* @param conf A SparkConf used to get the maximum number of retries when binding to a port.
* @param serviceName Name of the service.
+ * @return (service: T, port: Int)
*/
def startServiceOnPort[T](
startPort: Int,
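A hypothetical call illustrating the (service, port) return value the new @return tag documents (startServer, sparkConf, and the port are made-up names, and the keyword-argument shape is an assumption about this signature):

val (server, boundPort) = Utils.startServiceOnPort(
  startPort = 7077,
  startService = (port: Int) => startServer(port), // must return (service, actualPort)
  conf = sparkConf,
  serviceName = "example")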
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
index 120cf1b6fa..276c077b3d 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
@@ -113,7 +113,7 @@ class InboxSuite extends SparkFunSuite {
val remoteAddress = RpcAddress("localhost", 11111)
val inbox = new Inbox(endpointRef, endpoint)
- inbox.post(Associated(remoteAddress))
+ inbox.post(RemoteProcessConnected(remoteAddress))
inbox.process(dispatcher)
endpoint.verifySingleOnConnectedMessage(remoteAddress)
@@ -127,7 +127,7 @@ class InboxSuite extends SparkFunSuite {
val remoteAddress = RpcAddress("localhost", 11111)
val inbox = new Inbox(endpointRef, endpoint)
- inbox.post(Disassociated(remoteAddress))
+ inbox.post(RemoteProcessDisconnected(remoteAddress))
inbox.process(dispatcher)
endpoint.verifySingleOnDisconnectedMessage(remoteAddress)
@@ -142,7 +142,7 @@ class InboxSuite extends SparkFunSuite {
val cause = new RuntimeException("Oops")
val inbox = new Inbox(endpointRef, endpoint)
- inbox.post(AssociationError(cause, remoteAddress))
+ inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
inbox.process(dispatcher)
endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress)
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index 06ca035d19..f24f78b8c4 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -45,7 +45,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001))
nettyRpcHandler.receive(client, null, null)
- verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
+ verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
}
test("connectionTerminated") {
@@ -60,8 +60,9 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
nettyRpcHandler.connectionTerminated(client)
- verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
- verify(dispatcher, times(1)).broadcastMessage(Disassociated(RpcAddress("localhost", 12345)))
+ verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
+ verify(dispatcher, times(1)).postToAll(
+ RemoteProcessDisconnected(RpcAddress("localhost", 12345)))
}
}