aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala55
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala35
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala153
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala64
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java34
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/Message.java4
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java3
-rw-r--r--network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java75
-rw-r--r--network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java5
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java36
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java18
-rw-r--r--network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java2
-rw-r--r--network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java31
-rw-r--r--network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java9
18 files changed, 374 insertions, 192 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 3aef0515cb..25a17473e4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -92,7 +92,11 @@ private[deploy] class ExecutorRunner(
process.destroy()
exitCode = Some(process.waitFor())
}
- worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
+ try {
+ worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
+ } catch {
+ case e: IllegalStateException => logWarning(e.getMessage(), e)
+ }
}
/** Stop this executor runner, including killing the process it launched */
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index eb25d6c7b7..533c984766 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -106,44 +106,30 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
val iter = endpoints.keySet().iterator()
while (iter.hasNext) {
val name = iter.next
- postMessage(
- name,
- _ => message,
- () => { logWarning(s"Drop $message because $name has been stopped") })
+ postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e))
}
}
/** Posts a message sent by a remote endpoint. */
def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
- def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
- val rpcCallContext =
- new RemoteNettyRpcCallContext(
- nettyEnv, sender, callback, message.senderAddress, message.needReply)
- ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
- }
-
- def onEndpointStopped(): Unit = {
- callback.onFailure(
- new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
- }
-
- postMessage(message.receiver.name, createMessage, onEndpointStopped)
+ val rpcCallContext =
+ new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
+ val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
+ postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
}
/** Posts a message sent by a local endpoint. */
def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
- def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
- val rpcCallContext =
- new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p)
- ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
- }
-
- def onEndpointStopped(): Unit = {
- p.tryFailure(
- new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
- }
+ val rpcCallContext =
+ new LocalNettyRpcCallContext(message.senderAddress, p)
+ val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
+ postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
+ }
- postMessage(message.receiver.name, createMessage, onEndpointStopped)
+ /** Posts a one-way message. */
+ def postOneWayMessage(message: RequestMessage): Unit = {
+ postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),
+ (e) => throw e)
}
/**
@@ -155,21 +141,26 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
*/
private def postMessage(
endpointName: String,
- createMessageFn: NettyRpcEndpointRef => InboxMessage,
- callbackIfStopped: () => Unit): Unit = {
+ message: InboxMessage,
+ callbackIfStopped: (Exception) => Unit): Unit = {
val shouldCallOnStop = synchronized {
val data = endpoints.get(endpointName)
if (stopped || data == null) {
true
} else {
- data.inbox.post(createMessageFn(data.ref))
+ data.inbox.post(message)
receivers.offer(data)
false
}
}
if (shouldCallOnStop) {
// We don't need to call `onStop` in the `synchronized` block
- callbackIfStopped()
+ val error = if (stopped) {
+ new IllegalStateException("RpcEnv already stopped.")
+ } else {
+ new SparkException(s"Could not find $endpointName or it has been stopped.")
+ }
+ callbackIfStopped(error)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index 464027f07c..175463cc10 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -27,10 +27,13 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
private[netty] sealed trait InboxMessage
-private[netty] case class ContentMessage(
+private[netty] case class OneWayMessage(
+ senderAddress: RpcAddress,
+ content: Any) extends InboxMessage
+
+private[netty] case class RpcMessage(
senderAddress: RpcAddress,
content: Any,
- needReply: Boolean,
context: NettyRpcCallContext) extends InboxMessage
private[netty] case object OnStart extends InboxMessage
@@ -96,29 +99,24 @@ private[netty] class Inbox(
while (true) {
safelyCall(endpoint) {
message match {
- case ContentMessage(_sender, content, needReply, context) =>
- // The partial function to call
- val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive
+ case RpcMessage(_sender, content, context) =>
try {
- pf.applyOrElse[Any, Unit](content, { msg =>
+ endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
throw new SparkException(s"Unsupported message $message from ${_sender}")
})
- if (!needReply) {
- context.finish()
- }
} catch {
case NonFatal(e) =>
- if (needReply) {
- // If the sender asks a reply, we should send the error back to the sender
- context.sendFailure(e)
- } else {
- context.finish()
- }
+ context.sendFailure(e)
// Throw the exception -- this exception will be caught by the safelyCall function.
// The endpoint's onError function will be called.
throw e
}
+ case OneWayMessage(_sender, content) =>
+ endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
+ throw new SparkException(s"Unsupported message $message from ${_sender}")
+ })
+
case OnStart =>
endpoint.onStart()
if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
index 21d5bb4923..6637e2321f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
@@ -23,49 +23,28 @@ import org.apache.spark.Logging
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
-private[netty] abstract class NettyRpcCallContext(
- endpointRef: NettyRpcEndpointRef,
- override val senderAddress: RpcAddress,
- needReply: Boolean)
+private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress)
extends RpcCallContext with Logging {
protected def send(message: Any): Unit
override def reply(response: Any): Unit = {
- if (needReply) {
- send(AskResponse(endpointRef, response))
- } else {
- throw new IllegalStateException(
- s"Cannot send $response to the sender because the sender does not expect a reply")
- }
+ send(response)
}
override def sendFailure(e: Throwable): Unit = {
- if (needReply) {
- send(AskResponse(endpointRef, RpcFailure(e)))
- } else {
- logError(e.getMessage, e)
- throw new IllegalStateException(
- "Cannot send reply to the sender because the sender won't handle it")
- }
+ send(RpcFailure(e))
}
- def finish(): Unit = {
- if (!needReply) {
- send(Ack(endpointRef))
- }
- }
}
/**
* If the sender and the receiver are in the same process, the reply can be sent back via `Promise`.
*/
private[netty] class LocalNettyRpcCallContext(
- endpointRef: NettyRpcEndpointRef,
senderAddress: RpcAddress,
- needReply: Boolean,
p: Promise[Any])
- extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+ extends NettyRpcCallContext(senderAddress) {
override protected def send(message: Any): Unit = {
p.success(message)
@@ -77,11 +56,9 @@ private[netty] class LocalNettyRpcCallContext(
*/
private[netty] class RemoteNettyRpcCallContext(
nettyEnv: NettyRpcEnv,
- endpointRef: NettyRpcEndpointRef,
callback: RpcResponseCallback,
- senderAddress: RpcAddress,
- needReply: Boolean)
- extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+ senderAddress: RpcAddress)
+ extends NettyRpcCallContext(senderAddress) {
override protected def send(message: Any): Unit = {
val reply = nettyEnv.serialize(message)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index c8fa870f50..c7d74fa1d9 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -150,7 +150,7 @@ private[netty] class NettyRpcEnv(
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
if (receiver.client != null) {
- receiver.client.sendRpc(message.content, message.createCallback(receiver.client));
+ message.sendWith(receiver.client)
} else {
require(receiver.address != null,
"Cannot send message to client endpoint with no listen address.")
@@ -182,25 +182,10 @@ private[netty] class NettyRpcEnv(
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
// Message to a local RPC endpoint.
- val promise = Promise[Any]()
- dispatcher.postLocalMessage(message, promise)
- promise.future.onComplete {
- case Success(response) =>
- val ack = response.asInstanceOf[Ack]
- logTrace(s"Received ack from ${ack.sender}")
- case Failure(e) =>
- logWarning(s"Exception when sending $message", e)
- }(ThreadUtils.sameThread)
+ dispatcher.postOneWayMessage(message)
} else {
// Message to a remote RPC endpoint.
- postToOutbox(message.receiver, OutboxMessage(serialize(message),
- (e) => {
- logWarning(s"Exception when sending $message", e)
- },
- (client, response) => {
- val ack = deserialize[Ack](client, response)
- logDebug(s"Receive ack from ${ack.sender}")
- }))
+ postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
}
}
@@ -208,46 +193,52 @@ private[netty] class NettyRpcEnv(
clientFactory.createClient(address.host, address.port)
}
- private[netty] def ask(message: RequestMessage): Future[Any] = {
+ private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
+
+ def onFailure(e: Throwable): Unit = {
+ if (!promise.tryFailure(e)) {
+ logWarning(s"Ignored failure: $e")
+ }
+ }
+
+ def onSuccess(reply: Any): Unit = reply match {
+ case RpcFailure(e) => onFailure(e)
+ case rpcReply =>
+ if (!promise.trySuccess(rpcReply)) {
+ logWarning(s"Ignored message: $reply")
+ }
+ }
+
if (remoteAddr == address) {
val p = Promise[Any]()
- dispatcher.postLocalMessage(message, p)
p.future.onComplete {
- case Success(response) =>
- val reply = response.asInstanceOf[AskResponse]
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- case Failure(e) =>
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
+ case Success(response) => onSuccess(response)
+ case Failure(e) => onFailure(e)
}(ThreadUtils.sameThread)
+ dispatcher.postLocalMessage(message, p)
} else {
- postToOutbox(message.receiver, OutboxMessage(serialize(message),
- (e) => {
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
- },
- (client, response) => {
- val reply = deserialize[AskResponse](client, response)
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- }))
+ val rpcMessage = RpcOutboxMessage(serialize(message),
+ onFailure,
+ (client, response) => onSuccess(deserialize[Any](client, response)))
+ postToOutbox(message.receiver, rpcMessage)
+ promise.future.onFailure {
+ case _: TimeoutException => rpcMessage.onTimeout()
+ case _ =>
+ }(ThreadUtils.sameThread)
}
- promise.future
+
+ val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
+ override def run(): Unit = {
+ promise.tryFailure(
+ new TimeoutException("Cannot receive any reply in ${timeout.duration}"))
+ }
+ }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
+ promise.future.onComplete { v =>
+ timeoutCancelable.cancel(true)
+ }(ThreadUtils.sameThread)
+ promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
}
private[netty] def serialize(content: Any): Array[Byte] = {
@@ -512,25 +503,12 @@ private[netty] class NettyRpcEndpointRef(
override def name: String = _name
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
- val promise = Promise[Any]()
- val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable {
- override def run(): Unit = {
- promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration))
- }
- }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
- val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true))
- f.onComplete { v =>
- timeoutCancelable.cancel(true)
- if (!promise.tryComplete(v)) {
- logWarning(s"Ignore message $v")
- }
- }(ThreadUtils.sameThread)
- promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
+ nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
}
override def send(message: Any): Unit = {
require(message != null, "Message is null")
- nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false))
+ nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
}
override def toString: String = s"NettyRpcEndpointRef(${_address})"
@@ -549,24 +527,7 @@ private[netty] class NettyRpcEndpointRef(
* The message that is sent from the sender to the receiver.
*/
private[netty] case class RequestMessage(
- senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean)
-
-/**
- * The base trait for all messages that are sent back from the receiver to the sender.
- */
-private[netty] trait ResponseMessage
-
-/**
- * The reply for `ask` from the receiver side.
- */
-private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any)
- extends ResponseMessage
-
-/**
- * A message to send back to the receiver side. It's necessary because [[TransportClient]] only
- * clean the resources when it receives a reply.
- */
-private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage
+ senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any)
/**
* A response that indicates some failure happens in the receiver side.
@@ -598,6 +559,18 @@ private[netty] class NettyRpcHandler(
client: TransportClient,
message: Array[Byte],
callback: RpcResponseCallback): Unit = {
+ val messageToDispatch = internalReceive(client, message)
+ dispatcher.postRemoteMessage(messageToDispatch, callback)
+ }
+
+ override def receive(
+ client: TransportClient,
+ message: Array[Byte]): Unit = {
+ val messageToDispatch = internalReceive(client, message)
+ dispatcher.postOneWayMessage(messageToDispatch)
+ }
+
+ private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = {
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
@@ -605,14 +578,12 @@ private[netty] class NettyRpcHandler(
dispatcher.postToAll(RemoteProcessConnected(clientAddr))
}
val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
- val messageToDispatch = if (requestMessage.senderAddress == null) {
- // Create a new message with the socket address of the client as the sender.
- RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content,
- requestMessage.needReply)
- } else {
- requestMessage
- }
- dispatcher.postRemoteMessage(messageToDispatch, callback)
+ if (requestMessage.senderAddress == null) {
+ // Create a new message with the socket address of the client as the sender.
+ RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
+ } else {
+ requestMessage
+ }
}
override def getStreamManager: StreamManager = streamManager
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
index 2f6817f2eb..36fdd00bbc 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -22,22 +22,56 @@ import javax.annotation.concurrent.GuardedBy
import scala.util.control.NonFatal
-import org.apache.spark.SparkException
+import org.apache.spark.{Logging, SparkException}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.rpc.RpcAddress
-private[netty] case class OutboxMessage(content: Array[Byte],
- _onFailure: (Throwable) => Unit,
- _onSuccess: (TransportClient, Array[Byte]) => Unit) {
+private[netty] sealed trait OutboxMessage {
- def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() {
- override def onFailure(e: Throwable): Unit = {
- _onFailure(e)
- }
+ def sendWith(client: TransportClient): Unit
- override def onSuccess(response: Array[Byte]): Unit = {
- _onSuccess(client, response)
- }
+ def onFailure(e: Throwable): Unit
+
+}
+
+private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage
+ with Logging {
+
+ override def sendWith(client: TransportClient): Unit = {
+ client.send(content)
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ logWarning(s"Failed to send one-way RPC.", e)
+ }
+
+}
+
+private[netty] case class RpcOutboxMessage(
+ content: Array[Byte],
+ _onFailure: (Throwable) => Unit,
+ _onSuccess: (TransportClient, Array[Byte]) => Unit)
+ extends OutboxMessage with RpcResponseCallback {
+
+ private var client: TransportClient = _
+ private var requestId: Long = _
+
+ override def sendWith(client: TransportClient): Unit = {
+ this.client = client
+ this.requestId = client.sendRpc(content, this)
+ }
+
+ def onTimeout(): Unit = {
+ require(client != null, "TransportClient has not yet been set.")
+ client.removeRpcRequest(requestId)
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ _onFailure(e)
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ _onSuccess(client, response)
}
}
@@ -82,7 +116,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
}
}
if (dropped) {
- message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
} else {
drainOutbox()
}
@@ -122,7 +156,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
try {
val _client = synchronized { client }
if (_client != null) {
- _client.sendRpc(message.content, message.createCallback(_client))
+ message.sendWith(_client)
} else {
assert(stopped == true)
}
@@ -195,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message._onFailure(e)
+ message.onFailure(e)
message = messages.poll()
}
assert(messages.isEmpty)
@@ -229,7 +263,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
message = messages.poll()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
index 276c077b3d..2136795b18 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
@@ -35,7 +35,7 @@ class InboxSuite extends SparkFunSuite {
val dispatcher = mock(classOf[Dispatcher])
val inbox = new Inbox(endpointRef, endpoint)
- val message = ContentMessage(null, "hi", false, null)
+ val message = OneWayMessage(null, "hi")
inbox.post(message)
inbox.process(dispatcher)
assert(inbox.isEmpty)
@@ -55,7 +55,7 @@ class InboxSuite extends SparkFunSuite {
val dispatcher = mock(classOf[Dispatcher])
val inbox = new Inbox(endpointRef, endpoint)
- val message = ContentMessage(null, "hi", true, null)
+ val message = RpcMessage(null, "hi", null)
inbox.post(message)
inbox.process(dispatcher)
assert(inbox.isEmpty)
@@ -83,7 +83,7 @@ class InboxSuite extends SparkFunSuite {
new Thread {
override def run(): Unit = {
for (_ <- 0 until 100) {
- val message = ContentMessage(null, "hi", false, null)
+ val message = OneWayMessage(null, "hi")
inbox.post(message)
}
exitLatch.countDown()
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index ccca795683..323184cdd9 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -33,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
val sm = mock(classOf[StreamManager])
when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
- .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+ .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))
test("receive") {
val dispatcher = mock(classOf[Dispatcher])
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 876fcd8467..8a58e7b245 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -25,6 +25,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
@@ -36,6 +37,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.OneWayMessage;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamRequest;
@@ -205,8 +207,12 @@ public class TransportClient implements Closeable {
/**
* Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
* with the server's response or upon any failure.
+ *
+ * @param message The message to send.
+ * @param callback Callback to handle the RPC's reply.
+ * @return The RPC's id.
*/
- public void sendRpc(byte[] message, final RpcResponseCallback callback) {
+ public long sendRpc(byte[] message, final RpcResponseCallback callback) {
final String serverAddr = NettyUtils.getRemoteAddress(channel);
final long startTime = System.currentTimeMillis();
logger.trace("Sending RPC to {}", serverAddr);
@@ -235,6 +241,8 @@ public class TransportClient implements Closeable {
}
}
});
+
+ return requestId;
}
/**
@@ -265,11 +273,35 @@ public class TransportClient implements Closeable {
}
}
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the
+ * message, and no delivery guarantees are made.
+ *
+ * @param message The message to send.
+ */
+ public void send(byte[] message) {
+ channel.writeAndFlush(new OneWayMessage(message));
+ }
+
+ /**
+ * Removes any state associated with the given RPC.
+ *
+ * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}.
+ */
+ public void removeRpcRequest(long requestId) {
+ handler.removeRpcRequest(requestId);
+ }
+
/** Mark this channel as having timed out. */
public void timeOut() {
this.timedOut = true;
}
+ @VisibleForTesting
+ public TransportResponseHandler getHandler() {
+ return handler;
+ }
+
@Override
public void close() {
// close is a local operation and should finish with milliseconds; timeout just to be safe
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
index d01598c20f..39afd03db6 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -28,7 +28,8 @@ public interface Message extends Encodable {
public static enum Type implements Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
RpcRequest(3), RpcResponse(4), RpcFailure(5),
- StreamRequest(6), StreamResponse(7), StreamFailure(8);
+ StreamRequest(6), StreamResponse(7), StreamFailure(8),
+ OneWayMessage(9);
private final byte id;
@@ -55,6 +56,7 @@ public interface Message extends Encodable {
case 6: return StreamRequest;
case 7: return StreamResponse;
case 8: return StreamFailure;
+ case 9: return OneWayMessage;
default: throw new IllegalArgumentException("Unknown message type: " + id);
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index 3c04048f38..074780f2b9 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -63,6 +63,9 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
case RpcFailure:
return RpcFailure.decode(in);
+ case OneWayMessage:
+ return OneWayMessage.decode(in);
+
case StreamRequest:
return StreamRequest.decode(in);
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
new file mode 100644
index 0000000000..95a0270be3
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * A RPC that does not expect a reply, which is handled by a remote
+ * {@link org.apache.spark.network.server.RpcHandler}.
+ */
+public final class OneWayMessage implements RequestMessage {
+ /** Serialized message to send to remote RpcHandler. */
+ public final byte[] message;
+
+ public OneWayMessage(byte[] message) {
+ this.message = message;
+ }
+
+ @Override
+ public Type type() { return Type.OneWayMessage; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.ByteArrays.encodedLength(message);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.ByteArrays.encode(buf, message);
+ }
+
+ public static OneWayMessage decode(ByteBuf buf) {
+ byte[] message = Encoders.ByteArrays.decode(buf);
+ return new OneWayMessage(message);
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(message);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof OneWayMessage) {
+ OneWayMessage o = (OneWayMessage) other;
+ return Arrays.equals(message, o.message);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("message", message)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 7033adb9ca..830db94b89 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -109,6 +109,11 @@ class SaslRpcHandler extends RpcHandler {
}
@Override
+ public void receive(TransportClient client, byte[] message) {
+ delegate.receive(client, message);
+ }
+
+ @Override
public StreamManager getStreamManager() {
return delegate.getStreamManager();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index dbb7f95f55..65109ddfe1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -17,6 +17,9 @@
package org.apache.spark.network.server;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
@@ -24,6 +27,9 @@ import org.apache.spark.network.client.TransportClient;
* Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
*/
public abstract class RpcHandler {
+
+ private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();
+
/**
* Receive a single RPC message. Any exception thrown while in this method will be sent back to
* the client in string form as a standard RPC failure.
@@ -48,10 +54,40 @@ public abstract class RpcHandler {
public abstract StreamManager getStreamManager();
/**
+ * Receives an RPC message that does not expect a reply. The default implementation will
+ * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if
+ * any of the callback methods are called.
+ *
+ * @param client A channel client which enables the handler to make requests back to the sender
+ * of this RPC. This will always be the exact same object for a particular channel.
+ * @param message The serialized bytes of the RPC.
+ */
+ public void receive(TransportClient client, byte[] message) {
+ receive(client, message, ONE_WAY_CALLBACK);
+ }
+
+ /**
* Invoked when the connection associated with the given client has been invalidated.
* No further requests will come from this client.
*/
public void connectionTerminated(TransportClient client) { }
public void exceptionCaught(Throwable cause, TransportClient client) { }
+
+ private static class OneWayRpcCallback implements RpcResponseCallback {
+
+ private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class);
+
+ @Override
+ public void onSuccess(byte[] response) {
+ logger.warn("Response provided for one-way RPC.");
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.error("Error response provided for one-way RPC.", e);
+ }
+
+ }
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 4f67bd573b..db18ea77d1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.server;
+import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
@@ -27,13 +28,14 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.protocol.Encodable;
-import org.apache.spark.network.protocol.RequestMessage;
import org.apache.spark.network.protocol.ChunkFetchRequest;
-import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.OneWayMessage;
+import org.apache.spark.network.protocol.RequestMessage;
import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamRequest;
@@ -95,6 +97,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
processFetchRequest((ChunkFetchRequest) request);
} else if (request instanceof RpcRequest) {
processRpcRequest((RpcRequest) request);
+ } else if (request instanceof OneWayMessage) {
+ processOneWayMessage((OneWayMessage) request);
} else if (request instanceof StreamRequest) {
processStreamRequest((StreamRequest) request);
} else {
@@ -156,6 +160,14 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
}
}
+ private void processOneWayMessage(OneWayMessage req) {
+ try {
+ rpcHandler.receive(reverseClient, req.message);
+ } catch (Exception e) {
+ logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
+ }
+ }
+
/**
* Responds to a single message with some Encodable object. If a failure occurs while sending,
* it will be logged and the channel closed.
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 22b451fc0e..1aa20900ff 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.MessageDecoder;
import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.protocol.OneWayMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.RpcResponse;
@@ -84,6 +85,7 @@ public class ProtocolSuite {
testClientToServer(new RpcRequest(12345, new byte[0]));
testClientToServer(new RpcRequest(12345, new byte[100]));
testClientToServer(new StreamRequest("abcde"));
+ testClientToServer(new OneWayMessage(new byte[100]));
}
@Test
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 8eb56bdd98..88fa2258bb 100644
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -17,9 +17,11 @@
package org.apache.spark.network;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
import java.util.Set;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
@@ -46,6 +48,7 @@ public class RpcIntegrationSuite {
static TransportServer server;
static TransportClientFactory clientFactory;
static RpcHandler rpcHandler;
+ static List<String> oneWayMsgs;
@BeforeClass
public static void setUp() throws Exception {
@@ -65,11 +68,18 @@ public class RpcIntegrationSuite {
}
@Override
+ public void receive(TransportClient client, byte[] message) {
+ String msg = new String(message, Charsets.UTF_8);
+ oneWayMsgs.add(msg);
+ }
+
+ @Override
public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
};
TransportContext context = new TransportContext(conf, rpcHandler);
server = context.createServer();
clientFactory = context.createClientFactory();
+ oneWayMsgs = new ArrayList<>();
}
@AfterClass
@@ -158,6 +168,27 @@ public class RpcIntegrationSuite {
assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
}
+ @Test
+ public void sendOneWayMessage() throws Exception {
+ final String message = "no reply";
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ client.send(message.getBytes(Charsets.UTF_8));
+ assertEquals(0, client.getHandler().numOutstandingRequests());
+
+ // Make sure the message arrives.
+ long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+ while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) {
+ TimeUnit.MILLISECONDS.sleep(10);
+ }
+
+ assertEquals(1, oneWayMsgs.size());
+ assertEquals(message, oneWayMsgs.get(0));
+ } finally {
+ client.close();
+ }
+ }
+
private void assertErrorsContain(Set<String> errors, Set<String> contains) {
assertEquals(contains.size(), errors.size());
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index b146899670..a6f180bc40 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -21,6 +21,7 @@ import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.io.File;
+import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
@@ -353,6 +354,14 @@ public class SparkSaslSuite {
verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class));
}
+ @Test
+ public void testDelegates() throws Exception {
+ Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
+ for (Method m : rpcHandlerMethods) {
+ SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes());
+ }
+ }
+
private static class SaslTestCtx {
final TransportClient client;