aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorMarcelo Vanzin <vanzin@cloudera.com>2015-11-25 12:58:18 -0800
committerMarcelo Vanzin <vanzin@cloudera.com>2015-11-25 12:58:18 -0800
commit4e81783e92f464d479baaf93eccc3adb1496989a (patch)
tree6ba31cd598671110d0e38f0930d36f358cd9b82d /core
parentd29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0 (diff)
downloadspark-4e81783e92f464d479baaf93eccc3adb1496989a.tar.gz
spark-4e81783e92f464d479baaf93eccc3adb1496989a.tar.bz2
spark-4e81783e92f464d479baaf93eccc3adb1496989a.zip
[SPARK-11866][NETWORK][CORE] Make sure timed out RPCs are cleaned up.
This change does a couple of different things to make sure that the RpcEnv-level code and the network library agree about the status of outstanding RPCs. For RPCs that do not expect a reply ("RpcEnv.send"), support for one way messages (hello CORBA!) was added to the network layer. This is a "fire and forget" message that does not require any state to be kept by the TransportClient; as a result, the RpcEnv 'Ack' message is not needed anymore. For RPCs that do expect a reply ("RpcEnv.ask"), the network library now returns the internal RPC id; if the RpcEnv layer decides to time out the RPC before the network layer does, it now asks the TransportClient to forget about the RPC, so that if the network-level timeout occurs, the client is not killed. As part of implementing the above, I cleaned up some of the code in the netty rpc backend, removing types that were not necessary and factoring out some common code. Of interest is a slight change in the exceptions when posting messages to a stopped RpcEnv; that's mostly to avoid nasty error messages from the local-cluster backend when shutting down, which pollutes the terminal output. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #9917 from vanzin/SPARK-11866.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala55
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala35
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala153
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala64
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala2
8 files changed, 162 insertions, 187 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 3aef0515cb..25a17473e4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -92,7 +92,11 @@ private[deploy] class ExecutorRunner(
process.destroy()
exitCode = Some(process.waitFor())
}
- worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
+ try {
+ worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
+ } catch {
+ case e: IllegalStateException => logWarning(e.getMessage(), e)
+ }
}
/** Stop this executor runner, including killing the process it launched */
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index eb25d6c7b7..533c984766 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -106,44 +106,30 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
val iter = endpoints.keySet().iterator()
while (iter.hasNext) {
val name = iter.next
- postMessage(
- name,
- _ => message,
- () => { logWarning(s"Drop $message because $name has been stopped") })
+ postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e))
}
}
/** Posts a message sent by a remote endpoint. */
def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
- def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
- val rpcCallContext =
- new RemoteNettyRpcCallContext(
- nettyEnv, sender, callback, message.senderAddress, message.needReply)
- ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
- }
-
- def onEndpointStopped(): Unit = {
- callback.onFailure(
- new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
- }
-
- postMessage(message.receiver.name, createMessage, onEndpointStopped)
+ val rpcCallContext =
+ new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
+ val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
+ postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
}
/** Posts a message sent by a local endpoint. */
def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
- def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
- val rpcCallContext =
- new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p)
- ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
- }
-
- def onEndpointStopped(): Unit = {
- p.tryFailure(
- new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
- }
+ val rpcCallContext =
+ new LocalNettyRpcCallContext(message.senderAddress, p)
+ val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
+ postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
+ }
- postMessage(message.receiver.name, createMessage, onEndpointStopped)
+ /** Posts a one-way message. */
+ def postOneWayMessage(message: RequestMessage): Unit = {
+ postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),
+ (e) => throw e)
}
/**
@@ -155,21 +141,26 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
*/
private def postMessage(
endpointName: String,
- createMessageFn: NettyRpcEndpointRef => InboxMessage,
- callbackIfStopped: () => Unit): Unit = {
+ message: InboxMessage,
+ callbackIfStopped: (Exception) => Unit): Unit = {
val shouldCallOnStop = synchronized {
val data = endpoints.get(endpointName)
if (stopped || data == null) {
true
} else {
- data.inbox.post(createMessageFn(data.ref))
+ data.inbox.post(message)
receivers.offer(data)
false
}
}
if (shouldCallOnStop) {
// We don't need to call `onStop` in the `synchronized` block
- callbackIfStopped()
+ val error = if (stopped) {
+ new IllegalStateException("RpcEnv already stopped.")
+ } else {
+ new SparkException(s"Could not find $endpointName or it has been stopped.")
+ }
+ callbackIfStopped(error)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index 464027f07c..175463cc10 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -27,10 +27,13 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
private[netty] sealed trait InboxMessage
-private[netty] case class ContentMessage(
+private[netty] case class OneWayMessage(
+ senderAddress: RpcAddress,
+ content: Any) extends InboxMessage
+
+private[netty] case class RpcMessage(
senderAddress: RpcAddress,
content: Any,
- needReply: Boolean,
context: NettyRpcCallContext) extends InboxMessage
private[netty] case object OnStart extends InboxMessage
@@ -96,29 +99,24 @@ private[netty] class Inbox(
while (true) {
safelyCall(endpoint) {
message match {
- case ContentMessage(_sender, content, needReply, context) =>
- // The partial function to call
- val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive
+ case RpcMessage(_sender, content, context) =>
try {
- pf.applyOrElse[Any, Unit](content, { msg =>
+ endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
throw new SparkException(s"Unsupported message $message from ${_sender}")
})
- if (!needReply) {
- context.finish()
- }
} catch {
case NonFatal(e) =>
- if (needReply) {
- // If the sender asks a reply, we should send the error back to the sender
- context.sendFailure(e)
- } else {
- context.finish()
- }
+ context.sendFailure(e)
// Throw the exception -- this exception will be caught by the safelyCall function.
// The endpoint's onError function will be called.
throw e
}
+ case OneWayMessage(_sender, content) =>
+ endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
+ throw new SparkException(s"Unsupported message $message from ${_sender}")
+ })
+
case OnStart =>
endpoint.onStart()
if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
index 21d5bb4923..6637e2321f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
@@ -23,49 +23,28 @@ import org.apache.spark.Logging
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
-private[netty] abstract class NettyRpcCallContext(
- endpointRef: NettyRpcEndpointRef,
- override val senderAddress: RpcAddress,
- needReply: Boolean)
+private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress)
extends RpcCallContext with Logging {
protected def send(message: Any): Unit
override def reply(response: Any): Unit = {
- if (needReply) {
- send(AskResponse(endpointRef, response))
- } else {
- throw new IllegalStateException(
- s"Cannot send $response to the sender because the sender does not expect a reply")
- }
+ send(response)
}
override def sendFailure(e: Throwable): Unit = {
- if (needReply) {
- send(AskResponse(endpointRef, RpcFailure(e)))
- } else {
- logError(e.getMessage, e)
- throw new IllegalStateException(
- "Cannot send reply to the sender because the sender won't handle it")
- }
+ send(RpcFailure(e))
}
- def finish(): Unit = {
- if (!needReply) {
- send(Ack(endpointRef))
- }
- }
}
/**
* If the sender and the receiver are in the same process, the reply can be sent back via `Promise`.
*/
private[netty] class LocalNettyRpcCallContext(
- endpointRef: NettyRpcEndpointRef,
senderAddress: RpcAddress,
- needReply: Boolean,
p: Promise[Any])
- extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+ extends NettyRpcCallContext(senderAddress) {
override protected def send(message: Any): Unit = {
p.success(message)
@@ -77,11 +56,9 @@ private[netty] class LocalNettyRpcCallContext(
*/
private[netty] class RemoteNettyRpcCallContext(
nettyEnv: NettyRpcEnv,
- endpointRef: NettyRpcEndpointRef,
callback: RpcResponseCallback,
- senderAddress: RpcAddress,
- needReply: Boolean)
- extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+ senderAddress: RpcAddress)
+ extends NettyRpcCallContext(senderAddress) {
override protected def send(message: Any): Unit = {
val reply = nettyEnv.serialize(message)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index c8fa870f50..c7d74fa1d9 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -150,7 +150,7 @@ private[netty] class NettyRpcEnv(
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
if (receiver.client != null) {
- receiver.client.sendRpc(message.content, message.createCallback(receiver.client));
+ message.sendWith(receiver.client)
} else {
require(receiver.address != null,
"Cannot send message to client endpoint with no listen address.")
@@ -182,25 +182,10 @@ private[netty] class NettyRpcEnv(
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
// Message to a local RPC endpoint.
- val promise = Promise[Any]()
- dispatcher.postLocalMessage(message, promise)
- promise.future.onComplete {
- case Success(response) =>
- val ack = response.asInstanceOf[Ack]
- logTrace(s"Received ack from ${ack.sender}")
- case Failure(e) =>
- logWarning(s"Exception when sending $message", e)
- }(ThreadUtils.sameThread)
+ dispatcher.postOneWayMessage(message)
} else {
// Message to a remote RPC endpoint.
- postToOutbox(message.receiver, OutboxMessage(serialize(message),
- (e) => {
- logWarning(s"Exception when sending $message", e)
- },
- (client, response) => {
- val ack = deserialize[Ack](client, response)
- logDebug(s"Receive ack from ${ack.sender}")
- }))
+ postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
}
}
@@ -208,46 +193,52 @@ private[netty] class NettyRpcEnv(
clientFactory.createClient(address.host, address.port)
}
- private[netty] def ask(message: RequestMessage): Future[Any] = {
+ private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
+
+ def onFailure(e: Throwable): Unit = {
+ if (!promise.tryFailure(e)) {
+ logWarning(s"Ignored failure: $e")
+ }
+ }
+
+ def onSuccess(reply: Any): Unit = reply match {
+ case RpcFailure(e) => onFailure(e)
+ case rpcReply =>
+ if (!promise.trySuccess(rpcReply)) {
+ logWarning(s"Ignored message: $reply")
+ }
+ }
+
if (remoteAddr == address) {
val p = Promise[Any]()
- dispatcher.postLocalMessage(message, p)
p.future.onComplete {
- case Success(response) =>
- val reply = response.asInstanceOf[AskResponse]
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- case Failure(e) =>
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
+ case Success(response) => onSuccess(response)
+ case Failure(e) => onFailure(e)
}(ThreadUtils.sameThread)
+ dispatcher.postLocalMessage(message, p)
} else {
- postToOutbox(message.receiver, OutboxMessage(serialize(message),
- (e) => {
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
- },
- (client, response) => {
- val reply = deserialize[AskResponse](client, response)
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- }))
+ val rpcMessage = RpcOutboxMessage(serialize(message),
+ onFailure,
+ (client, response) => onSuccess(deserialize[Any](client, response)))
+ postToOutbox(message.receiver, rpcMessage)
+ promise.future.onFailure {
+ case _: TimeoutException => rpcMessage.onTimeout()
+ case _ =>
+ }(ThreadUtils.sameThread)
}
- promise.future
+
+ val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
+ override def run(): Unit = {
+ promise.tryFailure(
+ new TimeoutException("Cannot receive any reply in ${timeout.duration}"))
+ }
+ }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
+ promise.future.onComplete { v =>
+ timeoutCancelable.cancel(true)
+ }(ThreadUtils.sameThread)
+ promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
}
private[netty] def serialize(content: Any): Array[Byte] = {
@@ -512,25 +503,12 @@ private[netty] class NettyRpcEndpointRef(
override def name: String = _name
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
- val promise = Promise[Any]()
- val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable {
- override def run(): Unit = {
- promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration))
- }
- }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
- val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true))
- f.onComplete { v =>
- timeoutCancelable.cancel(true)
- if (!promise.tryComplete(v)) {
- logWarning(s"Ignore message $v")
- }
- }(ThreadUtils.sameThread)
- promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
+ nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
}
override def send(message: Any): Unit = {
require(message != null, "Message is null")
- nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false))
+ nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
}
override def toString: String = s"NettyRpcEndpointRef(${_address})"
@@ -549,24 +527,7 @@ private[netty] class NettyRpcEndpointRef(
* The message that is sent from the sender to the receiver.
*/
private[netty] case class RequestMessage(
- senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean)
-
-/**
- * The base trait for all messages that are sent back from the receiver to the sender.
- */
-private[netty] trait ResponseMessage
-
-/**
- * The reply for `ask` from the receiver side.
- */
-private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any)
- extends ResponseMessage
-
-/**
- * A message to send back to the receiver side. It's necessary because [[TransportClient]] only
- * clean the resources when it receives a reply.
- */
-private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage
+ senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any)
/**
* A response that indicates some failure happens in the receiver side.
@@ -598,6 +559,18 @@ private[netty] class NettyRpcHandler(
client: TransportClient,
message: Array[Byte],
callback: RpcResponseCallback): Unit = {
+ val messageToDispatch = internalReceive(client, message)
+ dispatcher.postRemoteMessage(messageToDispatch, callback)
+ }
+
+ override def receive(
+ client: TransportClient,
+ message: Array[Byte]): Unit = {
+ val messageToDispatch = internalReceive(client, message)
+ dispatcher.postOneWayMessage(messageToDispatch)
+ }
+
+ private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = {
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
@@ -605,14 +578,12 @@ private[netty] class NettyRpcHandler(
dispatcher.postToAll(RemoteProcessConnected(clientAddr))
}
val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
- val messageToDispatch = if (requestMessage.senderAddress == null) {
- // Create a new message with the socket address of the client as the sender.
- RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content,
- requestMessage.needReply)
- } else {
- requestMessage
- }
- dispatcher.postRemoteMessage(messageToDispatch, callback)
+ if (requestMessage.senderAddress == null) {
+ // Create a new message with the socket address of the client as the sender.
+ RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
+ } else {
+ requestMessage
+ }
}
override def getStreamManager: StreamManager = streamManager
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
index 2f6817f2eb..36fdd00bbc 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -22,22 +22,56 @@ import javax.annotation.concurrent.GuardedBy
import scala.util.control.NonFatal
-import org.apache.spark.SparkException
+import org.apache.spark.{Logging, SparkException}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.rpc.RpcAddress
-private[netty] case class OutboxMessage(content: Array[Byte],
- _onFailure: (Throwable) => Unit,
- _onSuccess: (TransportClient, Array[Byte]) => Unit) {
+private[netty] sealed trait OutboxMessage {
- def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() {
- override def onFailure(e: Throwable): Unit = {
- _onFailure(e)
- }
+ def sendWith(client: TransportClient): Unit
- override def onSuccess(response: Array[Byte]): Unit = {
- _onSuccess(client, response)
- }
+ def onFailure(e: Throwable): Unit
+
+}
+
+private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage
+ with Logging {
+
+ override def sendWith(client: TransportClient): Unit = {
+ client.send(content)
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ logWarning(s"Failed to send one-way RPC.", e)
+ }
+
+}
+
+private[netty] case class RpcOutboxMessage(
+ content: Array[Byte],
+ _onFailure: (Throwable) => Unit,
+ _onSuccess: (TransportClient, Array[Byte]) => Unit)
+ extends OutboxMessage with RpcResponseCallback {
+
+ private var client: TransportClient = _
+ private var requestId: Long = _
+
+ override def sendWith(client: TransportClient): Unit = {
+ this.client = client
+ this.requestId = client.sendRpc(content, this)
+ }
+
+ def onTimeout(): Unit = {
+ require(client != null, "TransportClient has not yet been set.")
+ client.removeRpcRequest(requestId)
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ _onFailure(e)
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ _onSuccess(client, response)
}
}
@@ -82,7 +116,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
}
}
if (dropped) {
- message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
} else {
drainOutbox()
}
@@ -122,7 +156,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
try {
val _client = synchronized { client }
if (_client != null) {
- _client.sendRpc(message.content, message.createCallback(_client))
+ message.sendWith(_client)
} else {
assert(stopped == true)
}
@@ -195,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message._onFailure(e)
+ message.onFailure(e)
message = messages.poll()
}
assert(messages.isEmpty)
@@ -229,7 +263,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
message = messages.poll()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
index 276c077b3d..2136795b18 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
@@ -35,7 +35,7 @@ class InboxSuite extends SparkFunSuite {
val dispatcher = mock(classOf[Dispatcher])
val inbox = new Inbox(endpointRef, endpoint)
- val message = ContentMessage(null, "hi", false, null)
+ val message = OneWayMessage(null, "hi")
inbox.post(message)
inbox.process(dispatcher)
assert(inbox.isEmpty)
@@ -55,7 +55,7 @@ class InboxSuite extends SparkFunSuite {
val dispatcher = mock(classOf[Dispatcher])
val inbox = new Inbox(endpointRef, endpoint)
- val message = ContentMessage(null, "hi", true, null)
+ val message = RpcMessage(null, "hi", null)
inbox.post(message)
inbox.process(dispatcher)
assert(inbox.isEmpty)
@@ -83,7 +83,7 @@ class InboxSuite extends SparkFunSuite {
new Thread {
override def run(): Unit = {
for (_ <- 0 until 100) {
- val message = ContentMessage(null, "hi", false, null)
+ val message = OneWayMessage(null, "hi")
inbox.post(message)
}
exitLatch.countDown()
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index ccca795683..323184cdd9 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -33,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
val sm = mock(classOf[StreamManager])
when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
- .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+ .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))
test("receive") {
val dispatcher = mock(classOf[Dispatcher])