diff options
author | Reynold Xin <rxin@databricks.com> | 2015-10-14 12:41:02 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-10-14 12:41:02 -0700 |
commit | cf2e0ae7205443f052463e8cb9334ae2b6df2d0e (patch) | |
tree | 593d9fb9ef83400d694b6cce268bfa29f9e5b132 /core/src | |
parent | 615cc858cf913522059b6ebdde65f0204f4fb030 (diff) | |
download | spark-cf2e0ae7205443f052463e8cb9334ae2b6df2d0e.tar.gz spark-cf2e0ae7205443f052463e8cb9334ae2b6df2d0e.tar.bz2 spark-cf2e0ae7205443f052463e8cb9334ae2b6df2d0e.zip |
[SPARK-11096] Post-hoc review Netty based RPC implementation - round 2
A few more changes:
1. Renamed IDVerifier -> RpcEndpointVerifier
2. Renamed NettyRpcAddress -> RpcEndpointAddress
3. Simplified NettyRpcHandler a bit by removing the connection count tracking. This is OK because spark.shuffle.io.numConnectionsPerPeer is now forced to 1 for the RPC transport configuration.
4. Reduced the default value of spark.rpc.connect.threads from 256 to 64. It would be great to eventually remove this extra thread pool.
5. Minor cleanup & documentation.
Author: Reynold Xin <rxin@databricks.com>
Closes #9112 from rxin/SPARK-11096.
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 9 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala | 7 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 114 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala (renamed from core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala) | 32 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala (renamed from core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala) | 21 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala | 2 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala | 3 |
7 files changed, 81 insertions, 107 deletions
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index ef491a0ae4..2c4a8b9a0a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -94,15 +94,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { } /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` - * asynchronously. - */ - def asyncSetupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { - asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) - } - - /** * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. * This is a blocking action. */ diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 398e9eafc1..f1a8273f15 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils +/** + * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s). + */ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private class EndpointData( @@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] // Track the receivers whose inboxes may contain messages. - private val receivers = new LinkedBlockingQueue[EndpointData]() + private val receivers = new LinkedBlockingQueue[EndpointData] /** * True if the dispatcher has been stopped. 
Once stopped, all messages posted will be bounced @@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private var stopped = false def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { - val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) + val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name) val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) synchronized { if (stopped) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 89b6df76c2..a2b28c524d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,7 +22,6 @@ import java.nio.ByteBuffer import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag @@ -45,8 +44,10 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = - SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0)) + // Override numConnectionsPerPeer to 1 for RPC. 
+ private val transportConf = SparkTransportConf.fromSparkConf( + conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) @@ -54,14 +55,14 @@ private[netty] class NettyRpcEnv( new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) private val clientFactory = { - val bootstraps: Seq[TransportClientBootstrap] = + val bootstraps: java.util.List[TransportClientBootstrap] = if (securityManager.isAuthenticationEnabled()) { - Seq(new SaslClientBootstrap(transportConf, "", securityManager, + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, securityManager.isSaslEncryptionEnabled())) } else { - Nil + java.util.Collections.emptyList[TransportClientBootstrap] } - transportContext.createClientFactory(bootstraps.asJava) + transportContext.createClientFactory(bootstraps) } val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") @@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv( // TODO: a non-blocking TransportClientFactory.createClient in future private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( "netty-rpc-connection", - conf.getInt("spark.rpc.connect.threads", 256)) + conf.getInt("spark.rpc.connect.threads", 64)) @volatile private var server: TransportServer = _ @@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv( java.util.Collections.emptyList() } server = transportContext.createServer(port, bootstraps) - dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) + dispatcher.registerRpcEndpoint( + RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } override lazy val address: RpcAddress = { @@ -96,11 +98,11 @@ private[netty] class NettyRpcEnv( } def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - val addr = NettyRpcAddress(uri) + val addr = RpcEndpointAddress(uri) val endpointRef = 
new NettyRpcEndpointRef(conf, addr, this) - val idVerifierRef = - new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this) - idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find => + val verifier = new NettyRpcEndpointRef( + conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this) + verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find => if (find) { Future.successful(endpointRef) } else { @@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv( private[netty] def send(message: RequestMessage): Unit = { val remoteAddr = message.receiver.address if (remoteAddr == address) { + // Message to a local RPC endpoint. val promise = Promise[Any]() dispatcher.postLocalMessage(message, promise) promise.future.onComplete { case Success(response) => val ack = response.asInstanceOf[Ack] - logDebug(s"Receive ack from ${ack.sender}") + logTrace(s"Received ack from ${ack.sender}") case Failure(e) => logError(s"Exception when sending $message", e) }(ThreadUtils.sameThread) } else { + // Message to a remote RPC endpoint. 
try { // `createClient` will block if it cannot find a known connection, so we should run it in // clientConnectionExecutor @@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv( } }) } catch { - case e: RejectedExecutionException => { + case e: RejectedExecutionException => if (!promise.tryFailure(e)) { logWarning(s"Ignore failure", e) } - } } } promise.future @@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv( } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new NettyRpcAddress(address.host, address.port, endpointName).toString + new RpcEndpointAddress(address.host, address.port, endpointName).toString override def shutdown(): Unit = { cleanup() @@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) @transient @volatile private var nettyEnv: NettyRpcEnv = _ - @transient @volatile private var _address: NettyRpcAddress = _ + @transient @volatile private var _address: RpcEndpointAddress = _ - def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) { + def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) { this(conf) this._address = _address this.nettyEnv = nettyEnv @@ -322,7 +325,7 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() - _address = in.readObject().asInstanceOf[NettyRpcAddress] + _address = in.readObject().asInstanceOf[RpcEndpointAddress] nettyEnv = NettyRpcEnv.currentEnv.value } @@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler( private type RemoteEnvAddress = RpcAddress // Store all client addresses and their NettyRpcEnv addresses. + // TODO: Is this even necessary? @GuardedBy("this") private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() - // Store the connections from other NettyRpcEnv addresses. 
We need to keep track of the connection - // count because `TransportClientFactory.createClient` will create multiple connections - // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection - // to send the message. See `TransportClientFactory.createClient` for more details. - @GuardedBy("this") - private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]() - override def receive( client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { val requestMessage = nettyEnv.deserialize[RequestMessage](message) - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val remoteEnvAddress = requestMessage.senderAddress val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage: Option[RemoteProcessConnected] = - synchronized { - // If the first connection to a remote RpcEnv is found, we should broadcast "Associated" - if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { - // clientAddr connects at the first time - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - // Increase the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count + 1) - if (count == 0) { - // This is the first connection, so fire "Associated" - Some(RemoteProcessConnected(remoteEnvAddress)) - } else { - None - } - } else { - None - } + + // TODO: Can we add connection callback (channel registered) to the underlying framework? + // A variable to track whether we should dispatch the RemoteProcessConnected message. 
+ var dispatchRemoteProcessConnected = false + synchronized { + if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { + // clientAddr connects at the first time, fire "RemoteProcessConnected" + dispatchRemoteProcessConnected = true } - broadcastMessage.foreach(dispatcher.postToAll) + } + if (dispatchRemoteProcessConnected) { + dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) + } dispatcher.postRemoteMessage(requestMessage, callback) } override def getStreamManager: StreamManager = new OneForOneStreamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) val broadcastMessage = @@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler( } override def connectionTerminated(client: TransportClient): Unit = { - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - // If the last connection to a remote RpcEnv is terminated, we should broadcast - // "Disassociated" - remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => - remoteAddresses -= clientAddr - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent") - if (count - 1 == 0) { - // We lost all clients, so clean up and fire "Disassociated" - remoteConnectionCount.remove(remoteEnvAddress) - Some(RemoteProcessDisconnected(remoteEnvAddress)) - } else { - // Decrease the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count - 1) - None - } - } + val messageOpt: 
Option[RemoteProcessDisconnected] = + synchronized { + remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => + remoteAddresses -= clientAddr + Some(RemoteProcessDisconnected(remoteEnvAddress)) } - broadcastMessage.foreach(dispatcher.postToAll) + } + messageOpt.foreach(dispatcher.postToAll) } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". // See java.net.Socket.getRemoteSocketAddress } } - } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala index 1876b25592..87b6236936 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -17,40 +17,44 @@ package org.apache.spark.rpc.netty -import java.net.URI - import org.apache.spark.SparkException import org.apache.spark.rpc.RpcAddress -private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) { +/** + * An address identifier for an RPC endpoint. + * + * @param host host name of the remote process. + * @param port the port the remote RPC environment binds to. + * @param name name of the remote endpoint. 
+ */ +private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) { def toRpcAddress: RpcAddress = RpcAddress(host, port) override val toString = s"spark://$name@$host:$port" } -private[netty] object NettyRpcAddress { +private[netty] object RpcEndpointAddress { - def apply(sparkUrl: String): NettyRpcAddress = { + def apply(sparkUrl: String): RpcEndpointAddress = { try { - val uri = new URI(sparkUrl) + val uri = new java.net.URI(sparkUrl) val host = uri.getHost val port = uri.getPort val name = uri.getUserInfo if (uri.getScheme != "spark" || - host == null || - port < 0 || - name == null || - (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null - uri.getFragment != null || - uri.getQuery != null) { + host == null || + port < 0 || + name == null || + (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null + uri.getFragment != null || + uri.getQuery != null) { throw new SparkException("Invalid Spark URL: " + sparkUrl) } - NettyRpcAddress(host, port, name) + RpcEndpointAddress(host, port, name) } catch { case e: java.net.URISyntaxException => throw new SparkException("Invalid Spark URL: " + sparkUrl, e) } } - } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala index fa9a3eb99b..99f20da2d6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala @@ -14,26 +14,27 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark.rpc.netty import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} /** - * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists - */ -private[netty] case class ID(name: String) - -/** - * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]] + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists. + * + * This is used when setting up a remote endpoint reference. */ -private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) +private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint { override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case ID(name) => context.reply(dispatcher.verify(name)) + case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name)) } } -private[netty] object IDVerifier { - val NAME = "id-verifier" +private[netty] object RpcEndpointVerifier { + val NAME = "endpoint-verifier" + + /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. 
*/ + case class CheckExistence(name: String) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index a5d43d3704..973a07a0bd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite class NettyRpcAddressSuite extends SparkFunSuite { test("toString") { - val addr = NettyRpcAddress("localhost", 12345, "test") + val addr = RpcEndpointAddress("localhost", 12345, "test") assert(addr.toString === "spark://test@localhost:12345") } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f24f78b8c4..5430e4c0c4 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -42,9 +42,6 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.receive(client, null, null) - when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001)) - nettyRpcHandler.receive(client, null, null) - verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) } |