aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
author Reynold Xin <rxin@databricks.com> 2015-10-14 12:41:02 -0700
committer Reynold Xin <rxin@databricks.com> 2015-10-14 12:41:02 -0700
commit cf2e0ae7205443f052463e8cb9334ae2b6df2d0e (patch)
tree 593d9fb9ef83400d694b6cce268bfa29f9e5b132 /core
parent 615cc858cf913522059b6ebdde65f0204f4fb030 (diff)
download spark-cf2e0ae7205443f052463e8cb9334ae2b6df2d0e.tar.gz
spark-cf2e0ae7205443f052463e8cb9334ae2b6df2d0e.tar.bz2
spark-cf2e0ae7205443f052463e8cb9334ae2b6df2d0e.zip
[SPARK-11096] Post-hoc review Netty based RPC implementation - round 2
A few more changes:
1. Renamed IDVerifier -> RpcEndpointVerifier.
2. Renamed NettyRpcAddress -> RpcEndpointAddress.
3. Simplified NettyRpcHandler a bit by removing the connection count tracking. This is OK because I now force spark.shuffle.io.numConnectionsPerPeer to 1 for RPC.
4. Reduced the default of spark.rpc.connect.threads from 256 to 64. It would be great to eventually remove this extra thread pool.
5. Minor cleanup & documentation.

Author: Reynold Xin <rxin@databricks.com>
Closes #9112 from rxin/SPARK-11096.
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 9
-rw-r--r-- core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala | 7
-rw-r--r-- core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 114
-rw-r--r-- core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala (renamed from core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala) | 32
-rw-r--r-- core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala (renamed from core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala) | 21
-rw-r--r-- core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala | 2
-rw-r--r-- core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala | 3
7 files changed, 81 insertions, 107 deletions
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index ef491a0ae4..2c4a8b9a0a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -94,15 +94,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
}
/**
- * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`
- * asynchronously.
- */
- def asyncSetupEndpointRef(
- systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = {
- asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName))
- }
-
- /**
* Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`.
* This is a blocking action.
*/
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index 398e9eafc1..f1a8273f15 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc._
import org.apache.spark.util.ThreadUtils
+/**
+ * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
+ */
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private class EndpointData(
@@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
// Track the receivers whose inboxes may contain messages.
- private val receivers = new LinkedBlockingQueue[EndpointData]()
+ private val receivers = new LinkedBlockingQueue[EndpointData]
/**
* True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
@@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private var stopped = false
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
- val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name)
+ val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name)
val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
synchronized {
if (stopped) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 89b6df76c2..a2b28c524d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -22,7 +22,6 @@ import java.nio.ByteBuffer
import java.util.concurrent._
import javax.annotation.concurrent.GuardedBy
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag
@@ -45,8 +44,10 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
- private val transportConf =
- SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0))
+ // Override numConnectionsPerPeer to 1 for RPC.
+ private val transportConf = SparkTransportConf.fromSparkConf(
+ conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
+ conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
@@ -54,14 +55,14 @@ private[netty] class NettyRpcEnv(
new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
private val clientFactory = {
- val bootstraps: Seq[TransportClientBootstrap] =
+ val bootstraps: java.util.List[TransportClientBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
- Seq(new SaslClientBootstrap(transportConf, "", securityManager,
+ java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
securityManager.isSaslEncryptionEnabled()))
} else {
- Nil
+ java.util.Collections.emptyList[TransportClientBootstrap]
}
- transportContext.createClientFactory(bootstraps.asJava)
+ transportContext.createClientFactory(bootstraps)
}
val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
@@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv(
// TODO: a non-blocking TransportClientFactory.createClient in future
private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
"netty-rpc-connection",
- conf.getInt("spark.rpc.connect.threads", 256))
+ conf.getInt("spark.rpc.connect.threads", 64))
@volatile private var server: TransportServer = _
@@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv(
java.util.Collections.emptyList()
}
server = transportContext.createServer(port, bootstraps)
- dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
+ dispatcher.registerRpcEndpoint(
+ RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
override lazy val address: RpcAddress = {
@@ -96,11 +98,11 @@ private[netty] class NettyRpcEnv(
}
def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
- val addr = NettyRpcAddress(uri)
+ val addr = RpcEndpointAddress(uri)
val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
- val idVerifierRef =
- new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this)
- idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find =>
+ val verifier = new NettyRpcEndpointRef(
+ conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this)
+ verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
if (find) {
Future.successful(endpointRef)
} else {
@@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv(
private[netty] def send(message: RequestMessage): Unit = {
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
+ // Message to a local RPC endpoint.
val promise = Promise[Any]()
dispatcher.postLocalMessage(message, promise)
promise.future.onComplete {
case Success(response) =>
val ack = response.asInstanceOf[Ack]
- logDebug(s"Receive ack from ${ack.sender}")
+ logTrace(s"Received ack from ${ack.sender}")
case Failure(e) =>
logError(s"Exception when sending $message", e)
}(ThreadUtils.sameThread)
} else {
+ // Message to a remote RPC endpoint.
try {
// `createClient` will block if it cannot find a known connection, so we should run it in
// clientConnectionExecutor
@@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv(
}
})
} catch {
- case e: RejectedExecutionException => {
+ case e: RejectedExecutionException =>
if (!promise.tryFailure(e)) {
logWarning(s"Ignore failure", e)
}
- }
}
}
promise.future
@@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv(
}
override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
- new NettyRpcAddress(address.host, address.port, endpointName).toString
+ new RpcEndpointAddress(address.host, address.port, endpointName).toString
override def shutdown(): Unit = {
cleanup()
@@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
@transient @volatile private var nettyEnv: NettyRpcEnv = _
- @transient @volatile private var _address: NettyRpcAddress = _
+ @transient @volatile private var _address: RpcEndpointAddress = _
- def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
+ def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) {
this(conf)
this._address = _address
this.nettyEnv = nettyEnv
@@ -322,7 +325,7 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
- _address = in.readObject().asInstanceOf[NettyRpcAddress]
+ _address = in.readObject().asInstanceOf[RpcEndpointAddress]
nettyEnv = NettyRpcEnv.currentEnv.value
}
@@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler(
private type RemoteEnvAddress = RpcAddress
// Store all client addresses and their NettyRpcEnv addresses.
+ // TODO: Is this even necessary?
@GuardedBy("this")
private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
- // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection
- // count because `TransportClientFactory.createClient` will create multiple connections
- // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection
- // to send the message. See `TransportClientFactory.createClient` for more details.
- @GuardedBy("this")
- private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]()
-
override def receive(
client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
val requestMessage = nettyEnv.deserialize[RequestMessage](message)
- val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val remoteEnvAddress = requestMessage.senderAddress
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- val broadcastMessage: Option[RemoteProcessConnected] =
- synchronized {
- // If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
- if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
- // clientAddr connects at the first time
- val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
- // Increase the connection number of remoteEnvAddress
- remoteConnectionCount.put(remoteEnvAddress, count + 1)
- if (count == 0) {
- // This is the first connection, so fire "Associated"
- Some(RemoteProcessConnected(remoteEnvAddress))
- } else {
- None
- }
- } else {
- None
- }
+
+ // TODO: Can we add connection callback (channel registered) to the underlying framework?
+ // A variable to track whether we should dispatch the RemoteProcessConnected message.
+ var dispatchRemoteProcessConnected = false
+ synchronized {
+ if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
+ // clientAddr connects for the first time, so fire "RemoteProcessConnected"
+ dispatchRemoteProcessConnected = true
}
- broadcastMessage.foreach(dispatcher.postToAll)
+ }
+ if (dispatchRemoteProcessConnected) {
+ dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
+ }
dispatcher.postRemoteMessage(requestMessage, callback)
}
override def getStreamManager: StreamManager = new OneForOneStreamManager
override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
- val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
val broadcastMessage =
@@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler(
}
override def connectionTerminated(client: TransportClient): Unit = {
- val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- val broadcastMessage =
- synchronized {
- // If the last connection to a remote RpcEnv is terminated, we should broadcast
- // "Disassociated"
- remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
- remoteAddresses -= clientAddr
- val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
- assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent")
- if (count - 1 == 0) {
- // We lost all clients, so clean up and fire "Disassociated"
- remoteConnectionCount.remove(remoteEnvAddress)
- Some(RemoteProcessDisconnected(remoteEnvAddress))
- } else {
- // Decrease the connection number of remoteEnvAddress
- remoteConnectionCount.put(remoteEnvAddress, count - 1)
- None
- }
- }
+ val messageOpt: Option[RemoteProcessDisconnected] =
+ synchronized {
+ remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
+ remoteAddresses -= clientAddr
+ Some(RemoteProcessDisconnected(remoteEnvAddress))
}
- broadcastMessage.foreach(dispatcher.postToAll)
+ }
+ messageOpt.foreach(dispatcher.postToAll)
} else {
// If the channel is closed before connecting, its remoteAddress will be null. In this case,
// we can ignore it since we don't fire "Associated".
// See java.net.Socket.getRemoteSocketAddress
}
}
-
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
index 1876b25592..87b6236936 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
@@ -17,40 +17,44 @@
package org.apache.spark.rpc.netty
-import java.net.URI
-
import org.apache.spark.SparkException
import org.apache.spark.rpc.RpcAddress
-private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) {
+/**
+ * An address identifier for an RPC endpoint.
+ *
+ * @param host host name of the remote process.
+ * @param port the port the remote RPC environment binds to.
+ * @param name name of the remote endpoint.
+ */
+private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) {
def toRpcAddress: RpcAddress = RpcAddress(host, port)
override val toString = s"spark://$name@$host:$port"
}
-private[netty] object NettyRpcAddress {
+private[netty] object RpcEndpointAddress {
- def apply(sparkUrl: String): NettyRpcAddress = {
+ def apply(sparkUrl: String): RpcEndpointAddress = {
try {
- val uri = new URI(sparkUrl)
+ val uri = new java.net.URI(sparkUrl)
val host = uri.getHost
val port = uri.getPort
val name = uri.getUserInfo
if (uri.getScheme != "spark" ||
- host == null ||
- port < 0 ||
- name == null ||
- (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
- uri.getFragment != null ||
- uri.getQuery != null) {
+ host == null ||
+ port < 0 ||
+ name == null ||
+ (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
+ uri.getFragment != null ||
+ uri.getQuery != null) {
throw new SparkException("Invalid Spark URL: " + sparkUrl)
}
- NettyRpcAddress(host, port, name)
+ RpcEndpointAddress(host, port, name)
} catch {
case e: java.net.URISyntaxException =>
throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
}
}
-
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala
index fa9a3eb99b..99f20da2d6 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala
@@ -14,26 +14,27 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.rpc.netty
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
/**
- * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists
- */
-private[netty] case class ID(name: String)
-
-/**
- * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
+ * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists.
+ *
+ * This is used when setting up a remote endpoint reference.
*/
-private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
+private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
extends RpcEndpoint {
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case ID(name) => context.reply(dispatcher.verify(name))
+ case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name))
}
}
-private[netty] object IDVerifier {
- val NAME = "id-verifier"
+private[netty] object RpcEndpointVerifier {
+ val NAME = "endpoint-verifier"
+
+ /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */
+ case class CheckExistence(name: String)
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
index a5d43d3704..973a07a0bd 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite
class NettyRpcAddressSuite extends SparkFunSuite {
test("toString") {
- val addr = NettyRpcAddress("localhost", 12345, "test")
+ val addr = RpcEndpointAddress("localhost", 12345, "test")
assert(addr.toString === "spark://test@localhost:12345")
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index f24f78b8c4..5430e4c0c4 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -42,9 +42,6 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
nettyRpcHandler.receive(client, null, null)
- when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001))
- nettyRpcHandler.receive(client, null, null)
-
verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
}