diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/rpc/netty')
5 files changed, 142 insertions, 138 deletions
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index d71e6f01db..398e9eafc1 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,7 +17,7 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -38,12 +38,16 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val inbox = new Inbox(ref, endpoint) } - private val endpoints = new ConcurrentHashMap[String, EndpointData]() - private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + private val endpoints = new ConcurrentHashMap[String, EndpointData] + private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] // Track the receivers whose inboxes may contain messages. private val receivers = new LinkedBlockingQueue[EndpointData]() + /** + * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced + * immediately. + */ @GuardedBy("this") private var stopped = false @@ -59,7 +63,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } val data = endpoints.get(name) endpointRefs.put(data.endpoint, data.ref) - receivers.put(data) + receivers.put(data) // for the OnStart message } endpointRef } @@ -73,7 +77,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val data = endpoints.remove(name) if (data != null) { data.inbox.stop() - receivers.put(data) + receivers.put(data) // for the OnStop message } // Don't clean `endpointRefs` here because it's possible that some messages are being processed // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via @@ -91,19 +95,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } /** - * Send a message to all registered [[RpcEndpoint]]s. - * @param message + * Send a message to all registered [[RpcEndpoint]]s in this process. + * + * This can be used to make network events known to all end points (e.g. "a new node connected"). */ - def broadcastMessage(message: InboxMessage): Unit = { + def postToAll(message: InboxMessage): Unit = { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessageToInbox(name, (_) => message, - () => { logWarning(s"Drop ${message} because ${name} has been stopped") }) + postMessage( + name, + _ => message, + () => { logWarning(s"Drop $message because $name has been stopped") }) } } - def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { + /** Posts a message sent by a remote endpoint. */ + def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { val rpcCallContext = new RemoteNettyRpcCallContext( @@ -116,10 +124,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } - postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) + postMessage(message.receiver.name, createMessage, onEndpointStopped) } - def postMessage(message: RequestMessage, p: Promise[Any]): Unit = { + /** Posts a message sent by a local endpoint. */ + def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = { def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { val rpcCallContext = new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) @@ -131,39 +140,36 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } - postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) + postMessage(message.receiver.name, createMessage, onEndpointStopped) } - private def postMessageToInbox( + /** + * Posts a message to a specific endpoint. + * + * @param endpointName name of the endpoint. + * @param createMessageFn function to create the message. + * @param callbackIfStopped callback function if the endpoint is stopped. + */ + private def postMessage( endpointName: String, createMessageFn: NettyRpcEndpointRef => InboxMessage, - onStopped: () => Unit): Unit = { - val shouldCallOnStop = - synchronized { - val data = endpoints.get(endpointName) - if (stopped || data == null) { - true - } else { - data.inbox.post(createMessageFn(data.ref)) - receivers.put(data) - false - } + callbackIfStopped: () => Unit): Unit = { + val shouldCallOnStop = synchronized { + val data = endpoints.get(endpointName) + if (stopped || data == null) { + true + } else { + data.inbox.post(createMessageFn(data.ref)) + receivers.put(data) + false } + } if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block - onStopped() + callbackIfStopped() } } - private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism", - Runtime.getRuntime.availableProcessors()) - - private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop") - - (0 until parallelism) foreach { _ => - executor.execute(new MessageLoop) - } - def stop(): Unit = { synchronized { if (stopped) { @@ -174,12 +180,12 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { // Stop all endpoints. This will queue all endpoints for processing by the message loops. endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) // Enqueue a message that tells the message loops to stop. - receivers.put(PoisonEndpoint) - executor.shutdown() + receivers.put(PoisonPill) + threadpool.shutdown() } def awaitTermination(): Unit = { - executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) } /** @@ -189,15 +195,27 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { endpoints.containsKey(name) } + /** Thread pool used for dispatching messages. */ + private val threadpool: ThreadPoolExecutor = { + val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", + Runtime.getRuntime.availableProcessors()) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") + for (i <- 0 until numThreads) { + pool.execute(new MessageLoop) + } + pool + } + + /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { override def run(): Unit = { try { while (true) { try { val data = receivers.take() - if (data == PoisonEndpoint) { - // Put PoisonEndpoint back so that other MessageLoops can see it. - receivers.put(PoisonEndpoint) + if (data == PoisonPill) { + // Put PoisonPill back so that other MessageLoops can see it. + receivers.put(PoisonPill) return } data.inbox.process(Dispatcher.this) @@ -211,8 +229,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } } - /** - * A poison endpoint that indicates MessageLoop should exit its loop. - */ - private val PoisonEndpoint = new EndpointData(null, null, null) + /** A poison endpoint that indicates MessageLoop should exit its message loop. */ + private val PoisonPill = new EndpointData(null, null, null) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala index 6061c9b8de..fa9a3eb99b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala @@ -26,8 +26,8 @@ private[netty] case class ID(name: String) /** * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]] */ -private[netty] class IDVerifier( - override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint { +private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) + extends RpcEndpoint { override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ID(name) => context.reply(dispatcher.verify(name)) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index b669f59a28..c72b588db5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -17,14 +17,16 @@ package org.apache.spark.rpc.netty -import java.util.LinkedList import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal +import com.google.common.annotations.VisibleForTesting + import org.apache.spark.{Logging, SparkException} import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} + private[netty] sealed trait InboxMessage private[netty] case class ContentMessage( @@ -37,44 +39,40 @@ private[netty] case object OnStart extends InboxMessage private[netty] case object OnStop extends InboxMessage -/** - * A broadcast message that indicates connecting to a remote node. - */ -private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage +/** A message to tell all endpoints that a remote process has connected. */ +private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress) extends InboxMessage -/** - * A broadcast message that indicates a remote connection is lost. - */ -private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage +/** A message to tell all endpoints that a remote process has disconnected. */ +private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress) extends InboxMessage -/** - * A broadcast message that indicates a network error - */ -private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress) +/** A message to tell all endpoints that a network error has happened. */ +private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress) extends InboxMessage /** * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. - * @param endpointRef - * @param endpoint */ private[netty] class Inbox( val endpointRef: NettyRpcEndpointRef, - val endpoint: RpcEndpoint) extends Logging { + val endpoint: RpcEndpoint) + extends Logging { - inbox => + inbox => // Give this an alias so we can use it more clearly in closures. @GuardedBy("this") - protected val messages = new LinkedList[InboxMessage]() + protected val messages = new java.util.LinkedList[InboxMessage]() + /** True if the inbox (and its associated endpoint) is stopped. */ @GuardedBy("this") private var stopped = false + /** Allow multiple threads to process messages at the same time. */ @GuardedBy("this") private var enableConcurrent = false + /** The number of threads processing messages for this inbox. */ @GuardedBy("this") - private var workerCount = 0 + private var numActiveThreads = 0 // OnStart should be the first message to process inbox.synchronized { @@ -87,12 +85,12 @@ private[netty] class Inbox( def process(dispatcher: Dispatcher): Unit = { var message: InboxMessage = null inbox.synchronized { - if (!enableConcurrent && workerCount != 0) { + if (!enableConcurrent && numActiveThreads != 0) { return } message = messages.poll() if (message != null) { - workerCount += 1 + numActiveThreads += 1 } else { return } @@ -101,15 +99,11 @@ private[netty] class Inbox( safelyCall(endpoint) { message match { case ContentMessage(_sender, content, needReply, context) => - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(context) - } else { - endpoint.receive - } + // The partial function to call + val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive try { pf.applyOrElse[Any, Unit](content, { msg => - throw new SparkException(s"Unmatched message $message from ${_sender}") + throw new SparkException(s"Unsupported message $message from ${_sender}") }) if (!needReply) { context.finish() @@ -121,11 +115,13 @@ private[netty] class Inbox( context.sendFailure(e) } else { context.finish() - throw e } + // Throw the exception -- this exception will be caught by the safelyCall function. + // The endpoint's onError function will be called. + throw e } - case OnStart => { + case OnStart => endpoint.onStart() if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { inbox.synchronized { @@ -134,24 +130,22 @@ private[netty] class Inbox( } } } - } case OnStop => - val _workCount = inbox.synchronized { - workerCount - } - assert(_workCount == 1, s"There should be only one worker but was ${_workCount}") + val activeThreads = inbox.synchronized { inbox.numActiveThreads } + assert(activeThreads == 1, + s"There should be only a single active thread but found $activeThreads threads.") dispatcher.removeRpcEndpointRef(endpoint) endpoint.onStop() assert(isEmpty, "OnStop should be the last message") - case Associated(remoteAddress) => + case RemoteProcessConnected(remoteAddress) => endpoint.onConnected(remoteAddress) - case Disassociated(remoteAddress) => + case RemoteProcessDisconnected(remoteAddress) => endpoint.onDisconnected(remoteAddress) - case AssociationError(cause, remoteAddress) => + case RemoteProcessConnectionError(cause, remoteAddress) => endpoint.onNetworkError(cause, remoteAddress) } } @@ -159,33 +153,27 @@ private[netty] class Inbox( inbox.synchronized { // "enableConcurrent" will be set to false after `onStop` is called, so we should check it // every time. - if (!enableConcurrent && workerCount != 1) { + if (!enableConcurrent && numActiveThreads != 1) { // If we are not the only one worker, exit - workerCount -= 1 + numActiveThreads -= 1 return } message = messages.poll() if (message == null) { - workerCount -= 1 + numActiveThreads -= 1 return } } } } - def post(message: InboxMessage): Unit = { - val dropped = - inbox.synchronized { - if (stopped) { - // We already put "OnStop" into "messages", so we should drop further messages - true - } else { - messages.add(message) - false - } - } - if (dropped) { + def post(message: InboxMessage): Unit = inbox.synchronized { + if (stopped) { + // We already put "OnStop" into "messages", so we should drop further messages onDrop(message) + } else { + messages.add(message) + false } } @@ -203,24 +191,23 @@ private[netty] class Inbox( } } - // Visible for testing. + def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } + + /** Called when we are dropping a message. Test cases override this to test message dropping. */ + @VisibleForTesting protected def onDrop(message: InboxMessage): Unit = { - logWarning(s"Drop ${message} because $endpointRef is stopped") + logWarning(s"Drop $message because $endpointRef is stopped") } - def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } - + /** + * Calls action closure, and calls the endpoint's onError function in the case of exceptions. + */ private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logWarning(s"Ignore error", e) + try action catch { + case NonFatal(e) => + try endpoint.onError(e) catch { + case NonFatal(ee) => logError(s"Ignoring error", ee) } - } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 75dcc02a0c..21d5bb4923 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -26,7 +26,8 @@ import org.apache.spark.rpc.{RpcAddress, RpcCallContext} private[netty] abstract class NettyRpcCallContext( endpointRef: NettyRpcEndpointRef, override val senderAddress: RpcAddress, - needReply: Boolean) extends RpcCallContext with Logging { + needReply: Boolean) + extends RpcCallContext with Logging { protected def send(message: Any): Unit @@ -35,7 +36,7 @@ private[netty] abstract class NettyRpcCallContext( send(AskResponse(endpointRef, response)) } else { throw new IllegalStateException( - s"Cannot send $response to the sender because the sender won't handle it") + s"Cannot send $response to the sender because the sender does not expect a reply") } } @@ -63,7 +64,8 @@ private[netty] class LocalNettyRpcCallContext( endpointRef: NettyRpcEndpointRef, senderAddress: RpcAddress, needReply: Boolean, - p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + p: Promise[Any]) + extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { override protected def send(message: Any): Unit = { p.success(message) @@ -78,7 +80,8 @@ private[netty] class RemoteNettyRpcCallContext( endpointRef: NettyRpcEndpointRef, callback: RpcResponseCallback, senderAddress: RpcAddress, - needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + needReply: Boolean) + extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { override protected def send(message: Any): Unit = { val reply = nettyEnv.serialize(message) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 5522b40782..89b6df76c2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -19,7 +19,6 @@ package org.apache.spark.rpc.netty import java.io._ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer -import java.util.Arrays import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy @@ -77,19 +76,19 @@ private[netty] class NettyRpcEnv( @volatile private var server: TransportServer = _ def start(port: Int): Unit = { - val bootstraps: Seq[TransportServerBootstrap] = + val bootstraps: java.util.List[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { - Seq(new SaslServerBootstrap(transportConf, securityManager)) + java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) } else { - Nil + java.util.Collections.emptyList() } - server = transportContext.createServer(port, bootstraps.asJava) + server = transportContext.createServer(port, bootstraps) dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) } override lazy val address: RpcAddress = { require(server != null, "NettyRpcEnv has not yet started") - RpcAddress(host, server.getPort()) + RpcAddress(host, server.getPort) } override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @@ -119,7 +118,7 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { val promise = Promise[Any]() - dispatcher.postMessage(message, promise) + dispatcher.postLocalMessage(message, promise) promise.future.onComplete { case Success(response) => val ack = response.asInstanceOf[Ack] @@ -148,10 +147,9 @@ private[netty] class NettyRpcEnv( } }) } catch { - case e: RejectedExecutionException => { + case e: RejectedExecutionException => // `send` after shutting clientConnectionExecutor down, ignore it - logWarning(s"Cannot send ${message} because RpcEnv is stopped") - } + logWarning(s"Cannot send $message because RpcEnv is stopped") } } } @@ -161,7 +159,7 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { val p = Promise[Any]() - dispatcher.postMessage(message, p) + dispatcher.postLocalMessage(message, p) p.future.onComplete { case Success(response) => val reply = response.asInstanceOf[AskResponse] @@ -218,7 +216,7 @@ private[netty] class NettyRpcEnv( private[netty] def serialize(content: Any): Array[Byte] = { val buffer = javaSerializerInstance.serialize(content) - Arrays.copyOfRange( + java.util.Arrays.copyOfRange( buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) } @@ -425,7 +423,7 @@ private[netty] class NettyRpcHandler( assert(addr != null) val remoteEnvAddress = requestMessage.senderAddress val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = + val broadcastMessage: Option[RemoteProcessConnected] = synchronized { // If the first connection to a remote RpcEnv is found, we should broadcast "Associated" if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { @@ -435,7 +433,7 @@ private[netty] class NettyRpcHandler( remoteConnectionCount.put(remoteEnvAddress, count + 1) if (count == 0) { // This is the first connection, so fire "Associated" - Some(Associated(remoteEnvAddress)) + Some(RemoteProcessConnected(remoteEnvAddress)) } else { None } @@ -443,8 +441,8 @@ private[netty] class NettyRpcHandler( None } } - broadcastMessage.foreach(dispatcher.broadcastMessage) - dispatcher.postMessage(requestMessage, callback) + broadcastMessage.foreach(dispatcher.postToAll) + dispatcher.postRemoteMessage(requestMessage, callback) } override def getStreamManager: StreamManager = new OneForOneStreamManager @@ -455,12 +453,12 @@ private[netty] class NettyRpcHandler( val clientAddr = RpcAddress(addr.getHostName, addr.getPort) val broadcastMessage = synchronized { - remoteAddresses.get(clientAddr).map(AssociationError(cause, _)) + remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _)) } if (broadcastMessage.isEmpty) { logError(cause.getMessage, cause) } else { - dispatcher.broadcastMessage(broadcastMessage.get) + dispatcher.postToAll(broadcastMessage.get) } } else { // If the channel is closed before connecting, its remoteAddress will be null. @@ -485,7 +483,7 @@ private[netty] class NettyRpcHandler( if (count - 1 == 0) { // We lost all clients, so clean up and fire "Disassociated" remoteConnectionCount.remove(remoteEnvAddress) - Some(Disassociated(remoteEnvAddress)) + Some(RemoteProcessDisconnected(remoteEnvAddress)) } else { // Decrease the connection number of remoteEnvAddress remoteConnectionCount.put(remoteEnvAddress, count - 1) @@ -493,7 +491,7 @@ private[netty] class NettyRpcHandler( } } } - broadcastMessage.foreach(dispatcher.broadcastMessage) + broadcastMessage.foreach(dispatcher.postToAll) } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". |