author    zsxwing <zsxwing@gmail.com>    2015-10-22 21:01:01 -0700
committer Reynold Xin <rxin@databricks.com>    2015-10-22 21:01:01 -0700
commit    a88c66ca8780c7228dc909f904d31cd9464ee0e3 (patch)
tree      8739bd32dfd756ddf809e8eeb2164128e1f9dd5f /core
parent    34e71c6d89c1f2b6236dbf0d75cd12da08003c84 (diff)
[SPARK-11098][CORE] Add Outbox to cache outgoing messages and resolve the message ordering issue
The current NettyRpc has a message ordering issue because it uses a thread pool to send messages. For example, when the following two lines run in the same thread:

```
ref.send("A")
ref.send("B")
```

the remote endpoint may see "B" before "A", because "A" and "B" are sent in parallel. To resolve this issue, this PR adds an Outbox for each connection: if we are still connecting to the remote node when a message is sent, the message is cached in the Outbox and sent, in order, once the connection is established.

Author: zsxwing <zsxwing@gmail.com>

Closes #9197 from zsxwing/rpc-outbox.
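For context, the core idea can be sketched independently of Spark's RPC types: queue outgoing messages, and let a single drainer send them in FIFO order once the connection exists, so the receiver observes them in the order they were posted. The sketch below is illustrative only (the `Message`, `Connection`, and `SimpleOutbox` names are made up for this example); the real implementation, including stop/failure handling and the asynchronous connect task, is the new `Outbox.scala` shown in the diff.

```scala
import java.util.LinkedList

// Illustrative stand-ins; not Spark classes.
final case class Message(body: String)
final class Connection {
  def sendRpc(m: Message): Unit = println(s"sent ${m.body}")
}

final class SimpleOutbox {
  private val messages = new LinkedList[Message]  // pending messages, guarded by `this`
  private var connection: Connection = null       // null until the connect task finishes
  private var draining = false                    // true while some thread is sending

  /** Queue the message and try to drain; messages leave in the order they arrived. */
  def send(m: Message): Unit = {
    synchronized { messages.add(m) }
    drainOutbox()
  }

  /** Called once the (asynchronous) connection attempt succeeds. */
  def onConnected(c: Connection): Unit = {
    synchronized { connection = c }
    drainOutbox()
  }

  private def drainOutbox(): Unit = {
    var conn: Connection = null
    var next: Message = null
    synchronized {
      // Exit if the connection is not ready yet or another thread is already draining.
      if (connection == null || draining) return
      conn = connection
      next = messages.poll()
      if (next == null) return
      draining = true
    }
    while (next != null) {
      conn.sendRpc(next)  // one message at a time, so order is preserved
      synchronized {
        next = messages.poll()
        if (next == null) draining = false
      }
    }
  }
}
```

With this shape, two `send` calls made from the same thread always reach `sendRpc` in the same order, whether or not the connection was ready when they were posted.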
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 145
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala | 222
2 files changed, 310 insertions, 57 deletions
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index e01cf1a29e..284284eb80 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -20,6 +20,7 @@ import java.io._
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
@@ -70,12 +71,30 @@ private[netty] class NettyRpcEnv(
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
// to implement non-blocking send/ask.
// TODO: a non-blocking TransportClientFactory.createClient in future
- private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
+ private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
"netty-rpc-connection",
conf.getInt("spark.rpc.connect.threads", 64))
@volatile private var server: TransportServer = _
+ private val stopped = new AtomicBoolean(false)
+
+ /**
+ * A map from [[RpcAddress]] to [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
+ * we just put messages into its [[Outbox]] to implement a non-blocking `send` method.
+ */
+ private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
+
+ /**
+ * Remove the address's Outbox and stop it.
+ */
+ private[netty] def removeOutbox(address: RpcAddress): Unit = {
+ val outbox = outboxes.remove(address)
+ if (outbox != null) {
+ outbox.stop()
+ }
+ }
+
def start(port: Int): Unit = {
val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
@@ -116,6 +135,30 @@ private[netty] class NettyRpcEnv(
dispatcher.stop(endpointRef)
}
+ private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = {
+ val targetOutbox = {
+ val outbox = outboxes.get(address)
+ if (outbox == null) {
+ val newOutbox = new Outbox(this, address)
+ val oldOutbox = outboxes.putIfAbsent(address, newOutbox)
+ if (oldOutbox == null) {
+ newOutbox
+ } else {
+ oldOutbox
+ }
+ } else {
+ outbox
+ }
+ }
+ if (stopped.get) {
+ // It's possible that `targetOutbox` was put into `outboxes` after this RpcEnv was stopped, so we need to clean it up.
+ outboxes.remove(address)
+ targetOutbox.stop()
+ } else {
+ targetOutbox.send(message)
+ }
+ }
+
private[netty] def send(message: RequestMessage): Unit = {
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
@@ -127,37 +170,28 @@ private[netty] class NettyRpcEnv(
val ack = response.asInstanceOf[Ack]
logTrace(s"Received ack from ${ack.sender}")
case Failure(e) =>
- logError(s"Exception when sending $message", e)
+ logWarning(s"Exception when sending $message", e)
}(ThreadUtils.sameThread)
} else {
// Message to a remote RPC endpoint.
- try {
- // `createClient` will block if it cannot find a known connection, so we should run it in
- // clientConnectionExecutor
- clientConnectionExecutor.execute(new Runnable {
- override def run(): Unit = Utils.tryLogNonFatalError {
- val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
- client.sendRpc(serialize(message), new RpcResponseCallback {
-
- override def onFailure(e: Throwable): Unit = {
- logError(s"Exception when sending $message", e)
- }
-
- override def onSuccess(response: Array[Byte]): Unit = {
- val ack = deserialize[Ack](response)
- logDebug(s"Receive ack from ${ack.sender}")
- }
- })
- }
- })
- } catch {
- case e: RejectedExecutionException =>
- // `send` after shutting clientConnectionExecutor down, ignore it
- logWarning(s"Cannot send $message because RpcEnv is stopped")
- }
+ postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
+
+ override def onFailure(e: Throwable): Unit = {
+ logWarning(s"Exception when sending $message", e)
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ val ack = deserialize[Ack](response)
+ logDebug(s"Receive ack from ${ack.sender}")
+ }
+ }))
}
}
+ private[netty] def createClient(address: RpcAddress): TransportClient = {
+ clientFactory.createClient(address.host, address.port)
+ }
+
private[netty] def ask(message: RequestMessage): Future[Any] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
@@ -180,39 +214,25 @@ private[netty] class NettyRpcEnv(
}
}(ThreadUtils.sameThread)
} else {
- try {
- // `createClient` will block if it cannot find a known connection, so we should run it in
- // clientConnectionExecutor
- clientConnectionExecutor.execute(new Runnable {
- override def run(): Unit = {
- val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
- client.sendRpc(serialize(message), new RpcResponseCallback {
-
- override def onFailure(e: Throwable): Unit = {
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
- }
-
- override def onSuccess(response: Array[Byte]): Unit = {
- val reply = deserialize[AskResponse](response)
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- }
- })
- }
- })
- } catch {
- case e: RejectedExecutionException =>
+ postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
+
+ override def onFailure(e: Throwable): Unit = {
if (!promise.tryFailure(e)) {
- logWarning(s"Ignore failure", e)
+ logWarning("Ignore Exception", e)
}
- }
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ val reply = deserialize[AskResponse](response)
+ if (reply.reply.isInstanceOf[RpcFailure]) {
+ if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
+ logWarning(s"Ignore failure: ${reply.reply}")
+ }
+ } else if (!promise.trySuccess(reply.reply)) {
+ logWarning(s"Ignore message: ${reply}")
+ }
+ }
+ }))
}
promise.future
}
@@ -245,6 +265,16 @@ private[netty] class NettyRpcEnv(
}
private def cleanup(): Unit = {
+ if (!stopped.compareAndSet(false, true)) {
+ return
+ }
+
+ val iter = outboxes.values().iterator()
+ while (iter.hasNext()) {
+ val outbox = iter.next()
+ outboxes.remove(outbox.address)
+ outbox.stop()
+ }
if (timeoutScheduler != null) {
timeoutScheduler.shutdownNow()
}
@@ -463,6 +493,7 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ nettyEnv.removeOutbox(clientAddr)
val messageOpt: Option[RemoteProcessDisconnected] =
synchronized {
remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
new file mode 100644
index 0000000000..7d9d593b36
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent.Callable
+import javax.annotation.concurrent.GuardedBy
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.SparkException
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.rpc.RpcAddress
+
+private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback)
+
+private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
+
+ outbox => // Give this an alias so we can use it more clearly in closures.
+
+ @GuardedBy("this")
+ private val messages = new java.util.LinkedList[OutboxMessage]
+
+ @GuardedBy("this")
+ private var client: TransportClient = null
+
+ /**
+ * connectFuture points to the connect task. If there is no connect task, connectFuture will be
+ * null.
+ */
+ @GuardedBy("this")
+ private var connectFuture: java.util.concurrent.Future[Unit] = null
+
+ @GuardedBy("this")
+ private var stopped = false
+
+ /**
+ * Whether any thread is currently draining the message queue.
+ */
+ @GuardedBy("this")
+ private var draining = false
+
+ /**
+ * Send a message. If there is no active connection, cache it and launch a new connection. If
+ * [[Outbox]] is stopped, the sender will be notified with a [[SparkException]].
+ */
+ def send(message: OutboxMessage): Unit = {
+ val dropped = synchronized {
+ if (stopped) {
+ true
+ } else {
+ messages.add(message)
+ false
+ }
+ }
+ if (dropped) {
+ message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ } else {
+ drainOutbox()
+ }
+ }
+
+ /**
+ * Drain the message queue. If another thread is already draining, just exit. If the connection
+ * has not been established, launch a task in `nettyEnv.clientConnectionExecutor` to set up the
+ * connection.
+ */
+ private def drainOutbox(): Unit = {
+ var message: OutboxMessage = null
+ synchronized {
+ if (stopped) {
+ return
+ }
+ if (connectFuture != null) {
+ // We are connecting to the remote address, so just exit
+ return
+ }
+ if (client == null) {
+ // There is no connect task and the client is null, so we need to launch the connect task.
+ launchConnectTask()
+ return
+ }
+ if (draining) {
+ // Some other thread is already draining, so just exit
+ return
+ }
+ message = messages.poll()
+ if (message == null) {
+ return
+ }
+ draining = true
+ }
+ while (true) {
+ try {
+ val _client = synchronized { client }
+ if (_client != null) {
+ _client.sendRpc(message.content, message.callback)
+ } else {
+ assert(stopped)
+ }
+ } catch {
+ case NonFatal(e) =>
+ handleNetworkFailure(e)
+ return
+ }
+ synchronized {
+ if (stopped) {
+ return
+ }
+ message = messages.poll()
+ if (message == null) {
+ draining = false
+ return
+ }
+ }
+ }
+ }
+
+ private def launchConnectTask(): Unit = {
+ connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] {
+
+ override def call(): Unit = {
+ try {
+ val _client = nettyEnv.createClient(address)
+ outbox.synchronized {
+ client = _client
+ if (stopped) {
+ closeClient()
+ }
+ }
+ } catch {
+ case ie: InterruptedException =>
+ // exit
+ return
+ case NonFatal(e) =>
+ outbox.synchronized { connectFuture = null }
+ handleNetworkFailure(e)
+ return
+ }
+ outbox.synchronized { connectFuture = null }
+ // It's possible that no thread is draining now. If we don't drain here, the cached
+ // messages will not be sent until the next message arrives.
+ drainOutbox()
+ }
+ })
+ }
+
+ /**
+ * Stop [[Outbox]] due to a network failure and notify the pending messages of the cause.
+ */
+ private def handleNetworkFailure(e: Throwable): Unit = {
+ synchronized {
+ assert(connectFuture == null)
+ if (stopped) {
+ return
+ }
+ stopped = true
+ closeClient()
+ }
+ // Remove this Outbox from nettyEnv so that further messages will create a new Outbox along
+ // with a new connection
+ nettyEnv.removeOutbox(address)
+
+ // Notify the remaining messages of the connection failure
+ //
+ // We always check `stopped` before updating messages, so here we can make sure no thread will
+ // update messages and it's safe to just drain the queue.
+ var message = messages.poll()
+ while (message != null) {
+ message.callback.onFailure(e)
+ message = messages.poll()
+ }
+ assert(messages.isEmpty)
+ }
+
+ private def closeClient(): Unit = synchronized {
+ // Not sure if `client.close` is idempotent. Just for safety.
+ if (client != null) {
+ client.close()
+ }
+ client = null
+ }
+
+ /**
+ * Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a
+ * [[SparkException]].
+ */
+ def stop(): Unit = {
+ synchronized {
+ if (stopped) {
+ return
+ }
+ stopped = true
+ if (connectFuture != null) {
+ connectFuture.cancel(true)
+ }
+ closeClient()
+ }
+
+ // We always check `stopped` before updating messages, so here we can make sure no thread will
+ // update messages and it's safe to just drain the queue.
+ var message = messages.poll()
+ while (message != null) {
+ message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message = messages.poll()
+ }
+ }
+}