From 1797055dbf1d2fd7714d7c65c8d2efde2f15efc1 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 13 Oct 2015 09:51:20 -0700
Subject: [SPARK-11079] Post-hoc review Netty-based RPC - round 1

I'm going through the implementation right now for post-hoc review. Adding
more comments and renaming things as I go through them.

I also want to write higher-level documentation about how the whole thing
works -- but that will come in other pull requests.

Author: Reynold Xin

Closes #9091 from rxin/rpc-review.
---
 .../scala/org/apache/spark/MapOutputTracker.scala  |   2 +-
 .../scala/org/apache/spark/rpc/RpcAddress.scala    |  50 ++++++++
 .../scala/org/apache/spark/rpc/RpcEndpoint.scala   |   3 +-
 .../main/scala/org/apache/spark/rpc/RpcEnv.scala   | 153 +--------------------
 .../scala/org/apache/spark/rpc/RpcTimeout.scala    | 131 ++++++++++++++++++
 .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala     |   4 -
 .../org/apache/spark/rpc/netty/Dispatcher.scala    | 108 ++++++++-------
 .../org/apache/spark/rpc/netty/IDVerifier.scala    |   4 +-
 .../scala/org/apache/spark/rpc/netty/Inbox.scala   | 119 +++++++---------
 .../spark/rpc/netty/NettyRpcCallContext.scala      |  11 +-
 .../org/apache/spark/rpc/netty/NettyRpcEnv.scala   |  38 +++--
 .../scala/org/apache/spark/util/ThreadUtils.scala  |   1 -
 .../main/scala/org/apache/spark/util/Utils.scala   |   1 +
 .../org/apache/spark/rpc/netty/InboxSuite.scala    |   6 +-
 .../spark/rpc/netty/NettyRpcHandlerSuite.scala     |   7 +-
 15 files changed, 336 insertions(+), 302 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala
 create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 45e12e40c8..72355cdfa6 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -48,7 +48,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
       val hostPort = context.senderAddress.hostPort
       logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
       val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
-      val serializedSize = mapOutputStatuses.size
+      val serializedSize = mapOutputStatuses.length
       if (serializedSize > maxAkkaFrameSize) {
         val msg = s"Map output statuses were $serializedSize bytes which " +
           s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala
new file mode 100644
index 0000000000..eb0b26947f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import org.apache.spark.util.Utils
+
+
+/**
+ * Address for an RPC environment, with hostname and port.
+ */
+private[spark] case class RpcAddress(host: String, port: Int) {
+
+  def hostPort: String = host + ":" + port
+
+  /** Returns a string in the form of "spark://host:port". */
+  def toSparkURL: String = "spark://" + hostPort
+
+  override def toString: String = hostPort
+}
+
+
+private[spark] object RpcAddress {
+
+  /** Return the [[RpcAddress]] represented by `uri`. */
+  def fromURIString(uri: String): RpcAddress = {
+    val uriObj = new java.net.URI(uri)
+    RpcAddress(uriObj.getHost, uriObj.getPort)
+  }
+
+  /** Returns the [[RpcAddress]] encoded in the form of "spark://host:port" */
+  def fromSparkURL(sparkUrl: String): RpcAddress = {
+    val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
+    RpcAddress(host, port)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index f1ddc6d2cd..0ba9516952 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -145,5 +145,4 @@ private[spark] trait RpcEndpoint {
  * However, there is no guarantee that the same thread will be executing the same
  * [[ThreadSafeRpcEndpoint]] for different messages.
  */
-private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint {
-}
+private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 35e402c725..ef491a0ae4 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -17,12 +17,7 @@

 package org.apache.spark.rpc

-import java.net.URI
-import java.util.concurrent.TimeoutException
-
-import scala.concurrent.{Awaitable, Await, Future}
-import scala.concurrent.duration._
-import scala.language.postfixOps
+import scala.concurrent.Future

 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.util.{RpcUtils, Utils}
@@ -35,8 +30,8 @@ import org.apache.spark.util.{RpcUtils, Utils}
 private[spark] object RpcEnv {

   private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
-    // Add more RpcEnv implementations here
-    val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
+    val rpcEnvNames = Map(
+      "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
       "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")
     val rpcEnvName = conf.get("spark.rpc", "netty")
     val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
@@ -53,7 +48,6 @@ private[spark] object RpcEnv {
     val config = RpcEnvConfig(conf, name, host, port, securityManager)
     getRpcEnvFactory(conf).create(config)
   }
-
 }
@@ -155,144 +149,3 @@ private[spark] case class RpcEnvConfig(
     host: String,
     port: Int,
     securityManager: SecurityManager)
-
-
-/**
- * Represents a host and port.
- */
-private[spark] case class RpcAddress(host: String, port: Int) {
-  // TODO do we need to add the type of RpcEnv in the address?
-
-  val hostPort: String = host + ":" + port
-
-  override val toString: String = hostPort
-
-  def toSparkURL: String = "spark://" + hostPort
-}
-
-
-private[spark] object RpcAddress {
-
-  /**
-   * Return the [[RpcAddress]] represented by `uri`.
-   */
-  def fromURI(uri: URI): RpcAddress = {
-    RpcAddress(uri.getHost, uri.getPort)
-  }
-
-  /**
-   * Return the [[RpcAddress]] represented by `uri`.
-   */
-  def fromURIString(uri: String): RpcAddress = {
-    fromURI(new java.net.URI(uri))
-  }
-
-  def fromSparkURL(sparkUrl: String): RpcAddress = {
-    val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
-    RpcAddress(host, port)
-  }
-}
-
-
-/**
- * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
- */
-private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
-  extends TimeoutException(message) { initCause(cause) }
-
-
-/**
- * Associates a timeout with a description so that a when a TimeoutException occurs, additional
- * context about the timeout can be amended to the exception message.
- * @param duration timeout duration in seconds
- * @param timeoutProp the configuration property that controls this timeout
- */
-private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String)
-  extends Serializable {
-
-  /** Amends the standard message of TimeoutException to include the description */
-  private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
-    new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te)
-  }
-
-  /**
-   * PartialFunction to match a TimeoutException and add the timeout description to the message
-   *
-   * @note This can be used in the recover callback of a Future to add to a TimeoutException
-   * Example:
-   *    val timeout = new RpcTimeout(5 millis, "short timeout")
-   *    Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
-   */
-  def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
-    // The exception has already been converted to a RpcTimeoutException so just raise it
-    case rte: RpcTimeoutException => throw rte
-    // Any other TimeoutException get converted to a RpcTimeoutException with modified message
-    case te: TimeoutException => throw createRpcTimeoutException(te)
-  }
-
-  /**
-   * Wait for the completed result and return it. If the result is not available within this
-   * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout.
-   * @param awaitable the `Awaitable` to be awaited
-   * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
-   *         is still not ready
-   */
-  def awaitResult[T](awaitable: Awaitable[T]): T = {
-    try {
-      Await.result(awaitable, duration)
-    } catch addMessageIfTimeout
-  }
-}
-
-
-private[spark] object RpcTimeout {
-
-  /**
-   * Lookup the timeout property in the configuration and create
-   * a RpcTimeout with the property key in the description.
-   * @param conf configuration properties containing the timeout
-   * @param timeoutProp property key for the timeout in seconds
-   * @throws NoSuchElementException if property is not set
-   */
-  def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
-    val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds }
-    new RpcTimeout(timeout, timeoutProp)
-  }
-
-  /**
-   * Lookup the timeout property in the configuration and create
-   * a RpcTimeout with the property key in the description.
-   * Uses the given default value if property is not set
-   * @param conf configuration properties containing the timeout
-   * @param timeoutProp property key for the timeout in seconds
-   * @param defaultValue default timeout value in seconds if property not found
-   */
-  def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
-    val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds }
-    new RpcTimeout(timeout, timeoutProp)
-  }
-
-  /**
-   * Lookup prioritized list of timeout properties in the configuration
-   * and create a RpcTimeout with the first set property key in the
-   * description.
-   * Uses the given default value if property is not set
-   * @param conf configuration properties containing the timeout
-   * @param timeoutPropList prioritized list of property keys for the timeout in seconds
-   * @param defaultValue default timeout value in seconds if no properties found
-   */
-  def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = {
-    require(timeoutPropList.nonEmpty)
-
-    // Find the first set property or use the default value with the first property
-    val itr = timeoutPropList.iterator
-    var foundProp: Option[(String, String)] = None
-    while (itr.hasNext && foundProp.isEmpty){
-      val propKey = itr.next()
-      conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
-    }
-    val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
-    val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds }
-    new RpcTimeout(timeout, finalProp._1)
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
new file mode 100644
index 0000000000..285786ebf9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import java.util.concurrent.TimeoutException
+
+import scala.concurrent.{Awaitable, Await}
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.util.Utils
+
+
+/**
+ * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
+ */
+private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
+  extends TimeoutException(message) { initCause(cause) }
+
+
+/**
+ * Associates a timeout with a description so that when a TimeoutException occurs, additional
+ * context about the timeout can be amended to the exception message.
+ *
+ * @param duration timeout duration in seconds
+ * @param timeoutProp the configuration property that controls this timeout
+ */
+private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String)
+  extends Serializable {
+
+  /** Amends the standard message of TimeoutException to include the description */
+  private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
+    new RpcTimeoutException(te.getMessage + ". This timeout is controlled by " + timeoutProp, te)
+  }
+
+  /**
+   * PartialFunction to match a TimeoutException and add the timeout description to the message
+   *
+   * @note This can be used in the recover callback of a Future to add to a TimeoutException
+   * Example:
+   *    val timeout = new RpcTimeout(5 millis, "short timeout")
+   *    Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
+   */
+  def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
+    // The exception has already been converted to a RpcTimeoutException so just raise it
+    case rte: RpcTimeoutException => throw rte
+    // Any other TimeoutException gets converted to a RpcTimeoutException with a modified message
+    case te: TimeoutException => throw createRpcTimeoutException(te)
+  }
+
+  /**
+   * Wait for the completed result and return it. If the result is not available within this
+   * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout.
+   * @param awaitable the `Awaitable` to be awaited
+   * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
+   *         is still not ready
+   */
+  def awaitResult[T](awaitable: Awaitable[T]): T = {
+    try {
+      Await.result(awaitable, duration)
+    } catch addMessageIfTimeout
+  }
+}
+
+
+private[spark] object RpcTimeout {
+
+  /**
+   * Lookup the timeout property in the configuration and create
+   * a RpcTimeout with the property key in the description.
+   * @param conf configuration properties containing the timeout
+   * @param timeoutProp property key for the timeout in seconds
+   * @throws NoSuchElementException if property is not set
+   */
+  def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
+    val timeout = { conf.getTimeAsSeconds(timeoutProp).seconds }
+    new RpcTimeout(timeout, timeoutProp)
+  }
+
+  /**
+   * Lookup the timeout property in the configuration and create
+   * a RpcTimeout with the property key in the description.
+   * Uses the given default value if property is not set
+   * @param conf configuration properties containing the timeout
+   * @param timeoutProp property key for the timeout in seconds
+   * @param defaultValue default timeout value in seconds if property not found
+   */
+  def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
+    val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue).seconds }
+    new RpcTimeout(timeout, timeoutProp)
+  }
+
+  /**
+   * Lookup prioritized list of timeout properties in the configuration
+   * and create a RpcTimeout with the first set property key in the
+   * description.
+   * Uses the given default value if property is not set
+   * @param conf configuration properties containing the timeout
+   * @param timeoutPropList prioritized list of property keys for the timeout in seconds
+   * @param defaultValue default timeout value in seconds if no properties found
+   */
+  def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = {
+    require(timeoutPropList.nonEmpty)
+
+    // Find the first set property or use the default value with the first property
+    val itr = timeoutPropList.iterator
+    var foundProp: Option[(String, String)] = None
+    while (itr.hasNext && foundProp.isEmpty){
+      val propKey = itr.next()
+      conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
+    }
+    val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
+    val timeout = { Utils.timeStringAsSeconds(finalProp._2).seconds }
+    new RpcTimeout(timeout, finalProp._1)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 95132a4e4a..3fad595a0d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -39,10 +39,6 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
  *
  * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and
  * remove Akka from the dependencies.
- *
- * @param actorSystem
- * @param conf
- * @param boundPort
  */
 private[spark] class AkkaRpcEnv private[akka] (
     val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index d71e6f01db..398e9eafc1 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.rpc.netty

-import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
 import javax.annotation.concurrent.GuardedBy

 import scala.collection.JavaConverters._
@@ -38,12 +38,16 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
     val inbox = new Inbox(ref, endpoint)
   }

-  private val endpoints = new ConcurrentHashMap[String, EndpointData]()
-  private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]()
+  private val endpoints = new ConcurrentHashMap[String, EndpointData]
+  private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]

   // Track the receivers whose inboxes may contain messages.
   private val receivers = new LinkedBlockingQueue[EndpointData]()

+  /**
+   * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
+   * immediately.
+   */
   @GuardedBy("this")
   private var stopped = false

@@ -59,7 +63,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
       }
       val data = endpoints.get(name)
       endpointRefs.put(data.endpoint, data.ref)
-      receivers.put(data)
+      receivers.put(data)  // for the OnStart message
     }
     endpointRef
   }
@@ -73,7 +77,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
     val data = endpoints.remove(name)
     if (data != null) {
       data.inbox.stop()
-      receivers.put(data)
+      receivers.put(data)  // for the OnStop message
     }
     // Don't clean `endpointRefs` here because it's possible that some messages are being processed
     // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
@@ -91,19 +95,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   }

   /**
-   * Send a message to all registered [[RpcEndpoint]]s.
-   * @param message
+   * Send a message to all registered [[RpcEndpoint]]s in this process.
+   *
+   * This can be used to make network events known to all end points (e.g. "a new node connected").
    */
-  def broadcastMessage(message: InboxMessage): Unit = {
+  def postToAll(message: InboxMessage): Unit = {
     val iter = endpoints.keySet().iterator()
     while (iter.hasNext) {
       val name = iter.next
-      postMessageToInbox(name, (_) => message,
-        () => { logWarning(s"Drop ${message} because ${name} has been stopped") })
+      postMessage(
+        name,
+        _ => message,
+        () => { logWarning(s"Drop $message because $name has been stopped") })
     }
   }

-  def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
+  /** Posts a message sent by a remote endpoint. */
+  def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
     def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
       val rpcCallContext = new RemoteNettyRpcCallContext(
@@ -116,10 +124,11 @@
       new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
     }

-    postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
+    postMessage(message.receiver.name, createMessage, onEndpointStopped)
   }

-  def postMessage(message: RequestMessage, p: Promise[Any]): Unit = {
+  /** Posts a message sent by a local endpoint. */
+  def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
     def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
       val rpcCallContext = new LocalNettyRpcCallContext(
         sender, message.senderAddress, message.needReply, p)
@@ -131,39 +140,36 @@
       new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
     }

-    postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
+    postMessage(message.receiver.name, createMessage, onEndpointStopped)
   }

-  private def postMessageToInbox(
+  /**
+   * Posts a message to a specific endpoint.
+   *
+   * @param endpointName name of the endpoint.
+   * @param createMessageFn function to create the message.
+   * @param callbackIfStopped callback function if the endpoint is stopped.
+   */
+  private def postMessage(
       endpointName: String,
       createMessageFn: NettyRpcEndpointRef => InboxMessage,
-      onStopped: () => Unit): Unit = {
-    val shouldCallOnStop =
-      synchronized {
-        val data = endpoints.get(endpointName)
-        if (stopped || data == null) {
-          true
-        } else {
-          data.inbox.post(createMessageFn(data.ref))
-          receivers.put(data)
-          false
-        }
+      callbackIfStopped: () => Unit): Unit = {
+    val shouldCallOnStop = synchronized {
+      val data = endpoints.get(endpointName)
+      if (stopped || data == null) {
+        true
+      } else {
+        data.inbox.post(createMessageFn(data.ref))
+        receivers.put(data)
+        false
       }
+    }
     if (shouldCallOnStop) {
       // We don't need to call `onStop` in the `synchronized` block
-      onStopped()
+      callbackIfStopped()
     }
   }

-  private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism",
-    Runtime.getRuntime.availableProcessors())
-
-  private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop")
-
-  (0 until parallelism) foreach { _ =>
-    executor.execute(new MessageLoop)
-  }
-
   def stop(): Unit = {
     synchronized {
       if (stopped) {
@@ -174,12 +180,12 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
     // Stop all endpoints. This will queue all endpoints for processing by the message loops.
     endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
     // Enqueue a message that tells the message loops to stop.
-    receivers.put(PoisonEndpoint)
-    executor.shutdown()
+    receivers.put(PoisonPill)
+    threadpool.shutdown()
   }

   def awaitTermination(): Unit = {
-    executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
   }

   /**
@@ -189,15 +195,27 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
     endpoints.containsKey(name)
   }

+  /** Thread pool used for dispatching messages. */
+  private val threadpool: ThreadPoolExecutor = {
+    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
+      Runtime.getRuntime.availableProcessors())
+    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
+    for (i <- 0 until numThreads) {
+      pool.execute(new MessageLoop)
+    }
+    pool
+  }
+
+  /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
     override def run(): Unit = {
       try {
         while (true) {
           try {
             val data = receivers.take()
-            if (data == PoisonEndpoint) {
-              // Put PoisonEndpoint back so that other MessageLoops can see it.
-              receivers.put(PoisonEndpoint)
+            if (data == PoisonPill) {
+              // Put PoisonPill back so that other MessageLoops can see it.
+              receivers.put(PoisonPill)
               return
             }
             data.inbox.process(Dispatcher.this)
@@ -211,8 +229,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
     }
   }

-  /**
-   * A poison endpoint that indicates MessageLoop should exit its loop.
-   */
-  private val PoisonEndpoint = new EndpointData(null, null, null)
+  /** A poison endpoint that indicates MessageLoop should exit its message loop. */
+  private val PoisonPill = new EndpointData(null, null, null)
 }
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
index 6061c9b8de..fa9a3eb99b 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
@@ -26,8 +26,8 @@ private[netty] case class ID(name: String)
 /**
  * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
  */
-private[netty] class IDVerifier(
-    override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint {
+private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
+  extends RpcEndpoint {

   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case ID(name) => context.reply(dispatcher.verify(name))
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index b669f59a28..c72b588db5 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -17,14 +17,16 @@

 package org.apache.spark.rpc.netty

-import java.util.LinkedList
 import javax.annotation.concurrent.GuardedBy

 import scala.util.control.NonFatal

+import com.google.common.annotations.VisibleForTesting
+
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}

+
 private[netty] sealed trait InboxMessage

 private[netty] case class ContentMessage(
@@ -37,44 +39,40 @@ private[netty] case object OnStart extends InboxMessage

 private[netty] case object OnStop extends InboxMessage

-/**
- * A broadcast message that indicates connecting to a remote node.
- */
-private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage
+/** A message to tell all endpoints that a remote process has connected. */
+private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress) extends InboxMessage

-/**
- * A broadcast message that indicates a remote connection is lost.
- */
-private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage
+/** A message to tell all endpoints that a remote process has disconnected. */
+private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress) extends InboxMessage

-/**
- * A broadcast message that indicates a network error
- */
-private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress)
+/** A message to tell all endpoints that a network error has happened. */
+private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress)
   extends InboxMessage

 /**
  * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
- * @param endpointRef
- * @param endpoint
  */
 private[netty] class Inbox(
     val endpointRef: NettyRpcEndpointRef,
-    val endpoint: RpcEndpoint) extends Logging {
+    val endpoint: RpcEndpoint)
+  extends Logging {

-  inbox =>
+  inbox =>  // Give this an alias so we can use it more clearly in closures.

   @GuardedBy("this")
-  protected val messages = new LinkedList[InboxMessage]()
+  protected val messages = new java.util.LinkedList[InboxMessage]()

+  /** True if the inbox (and its associated endpoint) is stopped. */
   @GuardedBy("this")
   private var stopped = false

+  /** Allow multiple threads to process messages at the same time. */
   @GuardedBy("this")
   private var enableConcurrent = false

+  /** The number of threads processing messages for this inbox. */
   @GuardedBy("this")
-  private var workerCount = 0
+  private var numActiveThreads = 0

   // OnStart should be the first message to process
   inbox.synchronized {
@@ -87,12 +85,12 @@
   def process(dispatcher: Dispatcher): Unit = {
     var message: InboxMessage = null
     inbox.synchronized {
-      if (!enableConcurrent && workerCount != 0) {
+      if (!enableConcurrent && numActiveThreads != 0) {
         return
       }
       message = messages.poll()
       if (message != null) {
-        workerCount += 1
+        numActiveThreads += 1
       } else {
         return
       }
@@ -101,15 +99,11 @@ private[netty] class Inbox(
     safelyCall(endpoint) {
       message match {
         case ContentMessage(_sender, content, needReply, context) =>
-          val pf: PartialFunction[Any, Unit] =
-            if (needReply) {
-              endpoint.receiveAndReply(context)
-            } else {
-              endpoint.receive
-            }
+          // The partial function to call
+          val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive
           try {
             pf.applyOrElse[Any, Unit](content, { msg =>
-              throw new SparkException(s"Unmatched message $message from ${_sender}")
+              throw new SparkException(s"Unsupported message $message from ${_sender}")
             })
             if (!needReply) {
               context.finish()
@@ -121,11 +115,13 @@ private[netty] class Inbox(
               context.sendFailure(e)
             } else {
               context.finish()
-              throw e
             }
+            // Throw the exception -- this exception will be caught by the safelyCall function.
+            // The endpoint's onError function will be called.
+            throw e
           }

-        case OnStart => {
+        case OnStart =>
           endpoint.onStart()
           if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
             inbox.synchronized {
@@ -134,24 +130,22 @@ private[netty] class Inbox(
               }
             }
           }
-        }

         case OnStop =>
-          val _workCount = inbox.synchronized {
-            workerCount
-          }
-          assert(_workCount == 1, s"There should be only one worker but was ${_workCount}")
+          val activeThreads = inbox.synchronized { inbox.numActiveThreads }
+          assert(activeThreads == 1,
+            s"There should be only a single active thread but found $activeThreads threads.")
           dispatcher.removeRpcEndpointRef(endpoint)
           endpoint.onStop()
           assert(isEmpty, "OnStop should be the last message")

-        case Associated(remoteAddress) =>
+        case RemoteProcessConnected(remoteAddress) =>
           endpoint.onConnected(remoteAddress)

-        case Disassociated(remoteAddress) =>
+        case RemoteProcessDisconnected(remoteAddress) =>
           endpoint.onDisconnected(remoteAddress)

-        case AssociationError(cause, remoteAddress) =>
+        case RemoteProcessConnectionError(cause, remoteAddress) =>
           endpoint.onNetworkError(cause, remoteAddress)
       }
     }
@@ -159,33 +153,27 @@ private[netty] class Inbox(
     inbox.synchronized {
       // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
      // every time.
-      if (!enableConcurrent && workerCount != 1) {
+      if (!enableConcurrent && numActiveThreads != 1) {
         // If we are not the only one worker, exit
-        workerCount -= 1
+        numActiveThreads -= 1
         return
       }
       message = messages.poll()
       if (message == null) {
-        workerCount -= 1
+        numActiveThreads -= 1
         return
       }
     }
   }

-  def post(message: InboxMessage): Unit = {
-    val dropped =
-      inbox.synchronized {
-        if (stopped) {
-          // We already put "OnStop" into "messages", so we should drop further messages
-          true
-        } else {
-          messages.add(message)
-          false
-        }
-      }
-    if (dropped) {
+  def post(message: InboxMessage): Unit = inbox.synchronized {
+    if (stopped) {
+      // We already put "OnStop" into "messages", so we should drop further messages
       onDrop(message)
+    } else {
+      messages.add(message)
+      false
     }
   }
@@ -203,24 +191,23 @@ private[netty] class Inbox(
     }
   }

-  // Visible for testing.
+  def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }
+
+  /** Called when we are dropping a message. Test cases override this to test message dropping. */
+  @VisibleForTesting
   protected def onDrop(message: InboxMessage): Unit = {
-    logWarning(s"Drop ${message} because $endpointRef is stopped")
+    logWarning(s"Drop $message because $endpointRef is stopped")
   }

-  def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }
-
+  /**
+   * Calls action closure, and calls the endpoint's onError function in the case of exceptions.
+   */
   private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
-    try {
-      action
-    } catch {
-      case NonFatal(e) => {
-        try {
-          endpoint.onError(e)
-        } catch {
-          case NonFatal(e) => logWarning(s"Ignore error", e)
+    try action catch {
+      case NonFatal(e) =>
+        try endpoint.onError(e) catch {
+          case NonFatal(ee) => logError(s"Ignoring error", ee)
         }
-      }
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
index 75dcc02a0c..21d5bb4923 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
@@ -26,7 +26,8 @@ import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
 private[netty] abstract class NettyRpcCallContext(
     endpointRef: NettyRpcEndpointRef,
     override val senderAddress: RpcAddress,
-    needReply: Boolean) extends RpcCallContext with Logging {
+    needReply: Boolean)
+  extends RpcCallContext with Logging {

   protected def send(message: Any): Unit

@@ -35,7 +36,7 @@ private[netty] abstract class NettyRpcCallContext(
       send(AskResponse(endpointRef, response))
     } else {
       throw new IllegalStateException(
-        s"Cannot send $response to the sender because the sender won't handle it")
+        s"Cannot send $response to the sender because the sender does not expect a reply")
     }
   }

@@ -63,7 +64,8 @@ private[netty] class LocalNettyRpcCallContext(
     endpointRef: NettyRpcEndpointRef,
     senderAddress: RpcAddress,
     needReply: Boolean,
-    p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+    p: Promise[Any])
+  extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {

   override protected def send(message: Any): Unit = {
     p.success(message)
@@ -78,7 +80,8 @@ private[netty] class RemoteNettyRpcCallContext(
     endpointRef: NettyRpcEndpointRef,
     callback: RpcResponseCallback,
     senderAddress: RpcAddress,
-    needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+    needReply: Boolean)
+  extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {

   override protected def send(message: Any): Unit = {
     val reply = nettyEnv.serialize(message)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 5522b40782..89b6df76c2 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -19,7 +19,6 @@ package org.apache.spark.rpc.netty
 import java.io._
 import java.net.{InetSocketAddress, URI}
 import java.nio.ByteBuffer
-import java.util.Arrays
 import java.util.concurrent._
 import javax.annotation.concurrent.GuardedBy

@@ -77,19 +76,19 @@ private[netty] class NettyRpcEnv(
   @volatile private var server: TransportServer = _

   def start(port: Int): Unit = {
-    val bootstraps: Seq[TransportServerBootstrap] =
+    val bootstraps: java.util.List[TransportServerBootstrap] =
       if (securityManager.isAuthenticationEnabled()) {
-        Seq(new SaslServerBootstrap(transportConf, securityManager))
+        java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
       } else {
-        Nil
+        java.util.Collections.emptyList()
       }
-    server = transportContext.createServer(port, bootstraps.asJava)
+    server = transportContext.createServer(port, bootstraps)
     dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
   }

   override lazy val address: RpcAddress = {
     require(server != null, "NettyRpcEnv has not yet started")
-    RpcAddress(host, server.getPort())
+    RpcAddress(host, server.getPort)
   }

   override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
@@ -119,7 +118,7 @@ private[netty] class NettyRpcEnv(
     val remoteAddr = message.receiver.address
     if (remoteAddr == address) {
       val promise = Promise[Any]()
-      dispatcher.postMessage(message, promise)
+      dispatcher.postLocalMessage(message, promise)
       promise.future.onComplete {
         case Success(response) =>
           val ack = response.asInstanceOf[Ack]
@@ -148,10 +147,9 @@ private[netty] class NettyRpcEnv(
         }
       })
     } catch {
-      case e: RejectedExecutionException => {
+      case e: RejectedExecutionException =>
         // `send` after shutting clientConnectionExecutor down, ignore it
-        logWarning(s"Cannot send ${message} because RpcEnv is stopped")
-      }
+        logWarning(s"Cannot send $message because RpcEnv is stopped")
     }
   }
 }
@@ -161,7 +159,7 @@ private[netty] class NettyRpcEnv(
     val remoteAddr = message.receiver.address
     if (remoteAddr == address) {
       val p = Promise[Any]()
-      dispatcher.postMessage(message, p)
+      dispatcher.postLocalMessage(message, p)
       p.future.onComplete {
         case Success(response) =>
           val reply = response.asInstanceOf[AskResponse]
@@ -218,7 +216,7 @@ private[netty] class NettyRpcEnv(

   private[netty] def serialize(content: Any): Array[Byte] = {
     val buffer = javaSerializerInstance.serialize(content)
-    Arrays.copyOfRange(
+    java.util.Arrays.copyOfRange(
       buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
   }

@@ -425,7 +423,7 @@ private[netty] class NettyRpcHandler(
     assert(addr != null)
     val remoteEnvAddress = requestMessage.senderAddress
     val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-    val broadcastMessage =
+    val broadcastMessage: Option[RemoteProcessConnected] =
       synchronized {
         // If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
         if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
@@ -435,7 +433,7 @@ private[netty] class NettyRpcHandler(
           remoteConnectionCount.put(remoteEnvAddress, count + 1)
           if (count == 0) {
             // This is the first connection, so fire "Associated"
-            Some(Associated(remoteEnvAddress))
+            Some(RemoteProcessConnected(remoteEnvAddress))
           } else {
             None
           }
@@ -443,8 +441,8 @@ private[netty] class NettyRpcHandler(
           None
         }
       }
-    broadcastMessage.foreach(dispatcher.broadcastMessage)
-    dispatcher.postMessage(requestMessage, callback)
+    broadcastMessage.foreach(dispatcher.postToAll)
+    dispatcher.postRemoteMessage(requestMessage, callback)
   }

   override def getStreamManager: StreamManager = new OneForOneStreamManager

@@ -455,12 +453,12 @@ private[netty] class NettyRpcHandler(
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
       val broadcastMessage =
         synchronized {
-          remoteAddresses.get(clientAddr).map(AssociationError(cause, _))
+          remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _))
         }
       if (broadcastMessage.isEmpty) {
         logError(cause.getMessage, cause)
       } else {
-        dispatcher.broadcastMessage(broadcastMessage.get)
+        dispatcher.postToAll(broadcastMessage.get)
       }
     } else {
       // If the channel is closed before connecting, its remoteAddress will be null.
@@ -485,7 +483,7 @@ private[netty] class NettyRpcHandler(
           if (count - 1 == 0) {
             // We lost all clients, so clean up and fire "Disassociated"
             remoteConnectionCount.remove(remoteEnvAddress)
-            Some(Disassociated(remoteEnvAddress))
+            Some(RemoteProcessDisconnected(remoteEnvAddress))
           } else {
             // Decrease the connection number of remoteEnvAddress
             remoteConnectionCount.put(remoteEnvAddress, count - 1)
@@ -493,7 +491,7 @@ private[netty] class NettyRpcHandler(
           }
         }
       }
-      broadcastMessage.foreach(dispatcher.broadcastMessage)
+      broadcastMessage.foreach(dispatcher.postToAll)
     } else {
       // If the channel is closed before connecting, its remoteAddress will be null. In this case,
       // we can ignore it since we don't fire "Associated".
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 1ed098379e..15e7519d70 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -15,7 +15,6 @@
  * limitations under the License.
  */

-
 package org.apache.spark.util

 import java.util.concurrent._
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e60c1b355a..bd7e51c3b5 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1895,6 +1895,7 @@ private[spark] object Utils extends Logging {
    * This is expected to throw java.net.BindException on port collision.
    * @param conf A SparkConf used to get the maximum number of retries when binding to a port.
    * @param serviceName Name of the service.
+   * @return (service: T, port: Int)
    */
   def startServiceOnPort[T](
       startPort: Int,
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
index 120cf1b6fa..276c077b3d 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
@@ -113,7 +113,7 @@ class InboxSuite extends SparkFunSuite {
     val remoteAddress = RpcAddress("localhost", 11111)

     val inbox = new Inbox(endpointRef, endpoint)
-    inbox.post(Associated(remoteAddress))
+    inbox.post(RemoteProcessConnected(remoteAddress))
     inbox.process(dispatcher)

     endpoint.verifySingleOnConnectedMessage(remoteAddress)
@@ -127,7 +127,7 @@ class InboxSuite extends SparkFunSuite {
     val remoteAddress = RpcAddress("localhost", 11111)

     val inbox = new Inbox(endpointRef, endpoint)
-    inbox.post(Disassociated(remoteAddress))
+    inbox.post(RemoteProcessDisconnected(remoteAddress))
     inbox.process(dispatcher)

     endpoint.verifySingleOnDisconnectedMessage(remoteAddress)
@@ -142,7 +142,7 @@ class InboxSuite extends SparkFunSuite {
     val cause = new RuntimeException("Oops")

     val inbox = new Inbox(endpointRef, endpoint)
-    inbox.post(AssociationError(cause, remoteAddress))
+    inbox.post(RemoteProcessConnectionError(cause, remoteAddress))
     inbox.process(dispatcher)

     endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress)
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index 06ca035d19..f24f78b8c4 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -45,7 +45,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
     when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001))
     nettyRpcHandler.receive(client, null, null)

-    verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
+    verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
   }

   test("connectionTerminated") {
@@ -60,8 +60,9 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
     when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
     nettyRpcHandler.connectionTerminated(client)

-    verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
-    verify(dispatcher, times(1)).broadcastMessage(Disassociated(RpcAddress("localhost", 12345)))
+    verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
+    verify(dispatcher, times(1)).postToAll(
+      RemoteProcessDisconnected(RpcAddress("localhost", 12345)))
   }
 }
--
cgit v1.2.3
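The new RpcAddress.scala is self-contained enough to exercise on its own. Below is a minimal standalone sketch of the host:port round-trip it provides; the local case class and RpcAddressDemo are illustrative stand-ins, since the real class is private[spark]:

```scala
// Local mirror of the RpcAddress introduced above (illustrative only).
object RpcAddressDemo extends App {
  case class RpcAddress(host: String, port: Int) {
    def hostPort: String = host + ":" + port
    def toSparkURL: String = "spark://" + hostPort
    override def toString: String = hostPort
  }

  // fromURIString in the patch delegates to java.net.URI in the same way.
  def fromURIString(uri: String): RpcAddress = {
    val uriObj = new java.net.URI(uri)
    RpcAddress(uriObj.getHost, uriObj.getPort)
  }

  val addr = fromURIString("spark://master-host:7077")
  assert(addr == RpcAddress("master-host", 7077))
  assert(addr.toSparkURL == "spark://master-host:7077")
  println(addr) // prints "master-host:7077"
}
```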
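RpcTimeout.addMessageIfTimeout is meant to be dropped into a Future's recover callback, as its scaladoc example hints. A self-contained sketch of that pattern, using hypothetical DemoRpcTimeout/DemoRpcTimeoutException stand-ins for the private classes:

```scala
import java.util.concurrent.TimeoutException

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object RpcTimeoutDemo extends App {
  // Stand-ins mirroring the patch's RpcTimeoutException and RpcTimeout.
  class DemoRpcTimeoutException(message: String, cause: TimeoutException)
    extends TimeoutException(message) { initCause(cause) }

  class DemoRpcTimeout(val duration: FiniteDuration, val timeoutProp: String) {
    def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
      // Already amended: rethrow as-is so the property name isn't appended twice.
      case rte: DemoRpcTimeoutException => throw rte
      // Plain TimeoutException: amend the message with the controlling property.
      case te: TimeoutException =>
        throw new DemoRpcTimeoutException(
          te.getMessage + ". This timeout is controlled by " + timeoutProp, te)
    }
  }

  val timeout = new DemoRpcTimeout(5.millis, "spark.rpc.askTimeout")
  val failed = Future[String](throw new TimeoutException("Futures timed out"))
    .recover(timeout.addMessageIfTimeout)

  // The recovered future fails with a DemoRpcTimeoutException naming the property.
  Await.ready(failed, 1.second)
  println(failed.value)
}
```

The first case is what makes the combinator safe to stack: an exception that was already converted is raised unchanged, so chaining several recover calls never wraps it twice.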
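The three-argument RpcTimeout.apply implements first-set-property-wins: it scans the prioritized key list and, when nothing is set, pairs the default value with the first key so the eventual error message still names a real property. The same logic sketched against a plain Map standing in for SparkConf (chooseTimeout is a hypothetical helper, not the patch's code):

```scala
object TimeoutPropDemo extends App {
  // Stand-in for SparkConf.getOption: a plain Map of config keys.
  def chooseTimeout(
      conf: Map[String, String],
      timeoutPropList: Seq[String],
      defaultValue: String): (String, String) = {
    require(timeoutPropList.nonEmpty)
    timeoutPropList
      .collectFirst { case key if conf.contains(key) => key -> conf(key) }
      .getOrElse(timeoutPropList.head -> defaultValue)
  }

  val conf = Map("spark.network.timeout" -> "120s")
  // spark.rpc.askTimeout is unset, so the lookup falls through to the next key.
  assert(chooseTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "60s") ==
    ("spark.network.timeout" -> "120s"))
  // With nothing set, the default is attributed to the first (highest-priority) key.
  assert(chooseTimeout(Map.empty, Seq("spark.rpc.askTimeout"), "60s") ==
    ("spark.rpc.askTimeout" -> "60s"))
  println("ok")
}
```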
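The PoisonEndpoint-to-PoisonPill rename in Dispatcher points at a standard idiom: several message loops share one blocking queue, and each loop re-enqueues the sentinel before exiting so every other loop also sees it. A generic sketch of that idiom follows -- this is not Spark's Dispatcher, just the shutdown pattern it uses:

```scala
import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

object PoisonPillDemo extends App {
  sealed trait Task
  case class Work(id: Int) extends Task
  case object PoisonPill extends Task

  val queue = new LinkedBlockingQueue[Task]()
  val numThreads = 2
  val pool = Executors.newFixedThreadPool(numThreads)

  for (_ <- 0 until numThreads) {
    pool.execute(new Runnable {
      override def run(): Unit = {
        while (true) {
          queue.take() match {
            case PoisonPill =>
              // Put the pill back so the other loops can see it too, then exit.
              queue.put(PoisonPill)
              return
            case Work(id) =>
              println(s"processed $id")
          }
        }
      }
    })
  }

  (1 to 5).foreach(i => queue.put(Work(i)))
  queue.put(PoisonPill) // a single sentinel stops every loop
  pool.shutdown()
  pool.awaitTermination(10, TimeUnit.SECONDS)
}
```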
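NettyRpcHandler broadcasts RemoteProcessConnected only on the first connection from a remote RpcEnv and RemoteProcessDisconnected only when the last one goes away. A sketch of that reference-counting bookkeeping, with hypothetical onConnect/onDisconnect standing in for the handler's receive/connectionTerminated paths:

```scala
import scala.collection.mutable

object ConnectionTrackerDemo extends App {
  // Connection count per remote RpcEnv address, defaulting to zero.
  private val remoteConnectionCount = mutable.Map.empty[String, Int].withDefaultValue(0)

  def onConnect(remoteEnv: String): Option[String] = synchronized {
    remoteConnectionCount(remoteEnv) += 1
    // Only the first connection triggers a broadcast.
    if (remoteConnectionCount(remoteEnv) == 1) Some(s"RemoteProcessConnected($remoteEnv)")
    else None
  }

  def onDisconnect(remoteEnv: String): Option[String] = synchronized {
    remoteConnectionCount(remoteEnv) -= 1
    if (remoteConnectionCount(remoteEnv) == 0) {
      // Last client gone: clean up and broadcast.
      remoteConnectionCount.remove(remoteEnv)
      Some(s"RemoteProcessDisconnected($remoteEnv)")
    } else None
  }

  assert(onConnect("env-a:7077").isDefined)    // first connection broadcasts
  assert(onConnect("env-a:7077").isEmpty)      // second one is silent
  assert(onDisconnect("env-a:7077").isEmpty)   // one client still attached
  assert(onDisconnect("env-a:7077").isDefined) // last close broadcasts
  println("ok")
}
```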