Diffstat (limited to 'core/src/main/scala/org')
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala                  |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                          |  20
-rwxr-xr-x  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala              |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala       |  13
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala                |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala                   |  51
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala  |  22
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala                        |   7
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala               |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala              | 218
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala              |  39
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala                   | 220
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala         |  56
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala     |  87
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala             | 504
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ThreadUtils.scala                  |   6
17 files changed, 1211 insertions(+), 50 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 94eb8daa85..e380c5b8af 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -45,7 +45,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
- val hostPort = context.sender.address.hostPort
+ val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
val serializedSize = mapOutputStatuses.size
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index c6fef7f91f..cfde27fb2e 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -20,11 +20,10 @@ package org.apache.spark
import java.io.File
import java.net.Socket
-import akka.actor.ActorSystem
-
import scala.collection.mutable
import scala.util.Properties
+import akka.actor.ActorSystem
import com.google.common.collect.MapMaker
import org.apache.spark.annotation.DeveloperApi
@@ -41,7 +40,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
-import org.apache.spark.util.{RpcUtils, Utils}
+import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
/**
* :: DeveloperApi ::
@@ -57,6 +56,7 @@ import org.apache.spark.util.{RpcUtils, Utils}
class SparkEnv (
val executorId: String,
private[spark] val rpcEnv: RpcEnv,
+ _actorSystem: ActorSystem, // TODO Remove actorSystem
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -76,7 +76,7 @@ class SparkEnv (
// TODO Remove actorSystem
@deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0")
- val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+ val actorSystem: ActorSystem = _actorSystem
private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
@@ -100,6 +100,9 @@ class SparkEnv (
blockManager.master.stop()
metricsSystem.stop()
outputCommitCoordinator.stop()
+ if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) {
+ actorSystem.shutdown()
+ }
rpcEnv.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
@@ -249,7 +252,13 @@ object SparkEnv extends Logging {
// Create the ActorSystem for Akka and get the port it binds to.
val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
- val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+ val actorSystem: ActorSystem =
+ if (rpcEnv.isInstanceOf[AkkaRpcEnv]) {
+ rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+ } else {
+ // Create an ActorSystem for legacy code
+ AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1
+ }
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
if (isDriver) {
@@ -395,6 +404,7 @@ object SparkEnv extends Logging {
val envInstance = new SparkEnv(
executorId,
rpcEnv,
+ actorSystem,
serializer,
closureSerializer,
cacheManager,
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 770927c80f..93a1b3f310 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -329,7 +329,7 @@ private[deploy] class Worker(
registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate(
new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
- self.send(ReregisterWithMaster)
+ Option(self).foreach(_.send(ReregisterWithMaster))
}
},
INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index 735c4f0927..ab56fde938 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -24,14 +24,13 @@ import org.apache.spark.rpc._
* Actor which connects to a worker process and terminates the JVM if the connection is severed.
* Provides fate sharing between a worker and its associated child processes.
*/
-private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String)
+private[spark] class WorkerWatcher(
+ override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false)
extends RpcEndpoint with Logging {
- override def onStart() {
- logInfo(s"Connecting to worker $workerUrl")
- if (!isTesting) {
- rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
- }
+ logInfo(s"Connecting to worker $workerUrl")
+ if (!isTesting) {
+ rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
}
// Used to avoid shutting down JVM during tests
@@ -40,8 +39,6 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin
// true rather than calling `System.exit`. The user can check `isShutDown` to know if
// `exitNonZero` is called.
private[deploy] var isShutDown = false
- private[deploy] def setTesting(testing: Boolean) = isTesting = testing
- private var isTesting = false
// Lets filter events only from the worker's rpc system
private val expectedAddress = RpcAddress.fromURIString(workerUrl)
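With the mutable `setTesting` hook removed, tests now pass the flag through the constructor. A hypothetical usage sketch (the `rpcEnv` and worker URL are placeholders, and `isShutDown` is only visible from the `org.apache.spark.deploy` package):

    // Suppress System.exit in tests and observe the isShutDown flag instead.
    val watcher = new WorkerWatcher(rpcEnv, "spark://Worker@1.2.3.4:1234", isTesting = true)
    rpcEnv.setupEndpoint("worker-watcher-under-test", watcher)
    // ...simulate losing the connection to the worker...
    assert(watcher.isShutDown)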
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
index 3e5b64265e..f527ec86ab 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
@@ -37,5 +37,5 @@ private[spark] trait RpcCallContext {
/**
* The sender of this message.
*/
- def sender: RpcEndpointRef
+ def senderAddress: RpcAddress
}
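Endpoints that previously went through `context.sender` now only see the sender's `RpcAddress` and must answer through the context itself. A minimal sketch, using a hypothetical `EchoEndpoint` that is not part of this patch:

    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    private class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String =>
          // There is no RpcEndpointRef for the sender anymore, only its address;
          // replies always go back through the RpcCallContext.
          context.reply(s"echoing '$msg' to ${context.senderAddress.hostPort}")
      }
    }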
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index dfcbc51cdf..f1ddc6d2cd 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -29,20 +29,6 @@ private[spark] trait RpcEnvFactory {
}
/**
- * A trait that requires RpcEnv thread-safely sending messages to it.
- *
- * Thread-safety means processing of one message happens before processing of the next message by
- * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a
- * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
- * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
- *
- * However, there is no guarantee that the same thread will be executing the same
- * [[ThreadSafeRpcEndpoint]] for different messages.
- */
-private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
-
-
-/**
* An end point for the RPC that defines what functions to trigger given a message.
*
* It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
@@ -101,38 +87,39 @@ private[spark] trait RpcEndpoint {
}
/**
- * Invoked before [[RpcEndpoint]] starts to handle any message.
+ * Invoked when `remoteAddress` is connected to the current node.
*/
- def onStart(): Unit = {
+ def onConnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
/**
- * Invoked when [[RpcEndpoint]] is stopping.
+ * Invoked when `remoteAddress` is lost.
*/
- def onStop(): Unit = {
+ def onDisconnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
/**
- * Invoked when `remoteAddress` is connected to the current node.
+ * Invoked when some network error happens in the connection between the current node and
+ * `remoteAddress`.
*/
- def onConnected(remoteAddress: RpcAddress): Unit = {
+ def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
/**
- * Invoked when `remoteAddress` is lost.
+ * Invoked before [[RpcEndpoint]] starts to handle any message.
*/
- def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ def onStart(): Unit = {
// By default, do nothing.
}
/**
- * Invoked when some network error happens in the connection between the current node and
- * `remoteAddress`.
+ * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot
+ * use it to send or ask messages.
*/
- def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ def onStop(): Unit = {
// By default, do nothing.
}
@@ -146,3 +133,17 @@ private[spark] trait RpcEndpoint {
}
}
}
+
+/**
+ * A trait that requires RpcEnv thread-safely sending messages to it.
+ *
+ * Thread-safety means processing of one message happens before processing of the next message by
+ * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a
+ * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
+ * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
+ *
+ * However, there is no guarantee that the same thread will be executing the same
+ * [[ThreadSafeRpcEndpoint]] for different messages.
+ */
+private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint {
+}
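The relocated `ThreadSafeRpcEndpoint` is the right base trait for endpoints with mutable state; a minimal sketch with a hypothetical `CounterEndpoint`:

    import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

    private class CounterEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
      // Messages are processed one at a time with happens-before ordering,
      // so this field needs no @volatile or explicit locking.
      private var count = 0

      override def receive: PartialFunction[Any, Unit] = {
        case "inc" => count += 1
      }

      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case "get" => context.reply(count)
      }
    }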
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala
new file mode 100644
index 0000000000..d177881fb3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+import org.apache.spark.SparkException
+
+private[rpc] class RpcEndpointNotFoundException(uri: String)
+ extends SparkException(s"Cannot find endpoint: $uri")
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 29debe8081..afe1ee8a0b 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -36,8 +36,11 @@ private[spark] object RpcEnv {
private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
// Add more RpcEnv implementations here
- val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
- val rpcEnvName = conf.get("spark.rpc", "akka")
+ val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
+ "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")
+ // Use "netty" by default so that Jenkins can run all tests using NettyRpcEnv.
+ // Will change it back to "akka" before merging the new implementation.
+ val rpcEnvName = conf.get("spark.rpc", "netty")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
}
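The backend is selected by the `spark.rpc` setting, which takes one of the short names above or a fully qualified `RpcEnvFactory` class name. A hedged configuration sketch:

    import org.apache.spark.SparkConf

    // Force a specific RPC backend; "netty" and "akka" are the registered short
    // names, anything else is treated as an RpcEnvFactory class to load.
    val conf = new SparkConf().set("spark.rpc", "akka")
    // Equivalent on the command line: spark-submit --conf spark.rpc=akka ...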
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index ad67e1c5ad..95132a4e4a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -166,9 +166,9 @@ private[spark] class AkkaRpcEnv private[akka] (
_sender ! AkkaMessage(response, false)
}
- // Some RpcEndpoints need to know the sender's address
- override val sender: RpcEndpointRef =
- new AkkaRpcEndpointRef(defaultAddress, _sender, conf)
+ // Use "lazy" because most of RpcEndpoints don't need "senderAddress"
+ override lazy val senderAddress: RpcAddress =
+ new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address
})
} else {
endpoint.receive
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
new file mode 100644
index 0000000000..d71e6f01db
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.concurrent.Promise
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SparkException, Logging}
+import org.apache.spark.network.client.RpcResponseCallback
+import org.apache.spark.rpc._
+import org.apache.spark.util.ThreadUtils
+
+private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
+
+ private class EndpointData(
+ val name: String,
+ val endpoint: RpcEndpoint,
+ val ref: NettyRpcEndpointRef) {
+ val inbox = new Inbox(ref, endpoint)
+ }
+
+ private val endpoints = new ConcurrentHashMap[String, EndpointData]()
+ private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]()
+
+ // Track the receivers whose inboxes may contain messages.
+ private val receivers = new LinkedBlockingQueue[EndpointData]()
+
+ @GuardedBy("this")
+ private var stopped = false
+
+ def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
+ val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name)
+ val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
+ synchronized {
+ if (stopped) {
+ throw new IllegalStateException("RpcEnv has been stopped")
+ }
+ if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
+ throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
+ }
+ val data = endpoints.get(name)
+ endpointRefs.put(data.endpoint, data.ref)
+ receivers.put(data)
+ }
+ endpointRef
+ }
+
+ def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)
+
+ def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)
+
+ // Should be idempotent
+ private def unregisterRpcEndpoint(name: String): Unit = {
+ val data = endpoints.remove(name)
+ if (data != null) {
+ data.inbox.stop()
+ receivers.put(data)
+ }
+ // Don't clean `endpointRefs` here because it's possible that some messages are being processed
+ // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
+ // `removeRpcEndpointRef`.
+ }
+
+ def stop(rpcEndpointRef: RpcEndpointRef): Unit = {
+ synchronized {
+ if (stopped) {
+ // This endpoint will be stopped by Dispatcher.stop() method.
+ return
+ }
+ unregisterRpcEndpoint(rpcEndpointRef.name)
+ }
+ }
+
+ /**
+ * Send a message to all registered [[RpcEndpoint]]s.
+ * @param message
+ */
+ def broadcastMessage(message: InboxMessage): Unit = {
+ val iter = endpoints.keySet().iterator()
+ while (iter.hasNext) {
+ val name = iter.next
+ postMessageToInbox(name, (_) => message,
+ () => { logWarning(s"Drop ${message} because ${name} has been stopped") })
+ }
+ }
+
+ def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
+ def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
+ val rpcCallContext =
+ new RemoteNettyRpcCallContext(
+ nettyEnv, sender, callback, message.senderAddress, message.needReply)
+ ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
+ }
+
+ def onEndpointStopped(): Unit = {
+ callback.onFailure(
+ new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
+ }
+
+ postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
+ }
+
+ def postMessage(message: RequestMessage, p: Promise[Any]): Unit = {
+ def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
+ val rpcCallContext =
+ new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p)
+ ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
+ }
+
+ def onEndpointStopped(): Unit = {
+ p.tryFailure(
+ new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
+ }
+
+ postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
+ }
+
+ private def postMessageToInbox(
+ endpointName: String,
+ createMessageFn: NettyRpcEndpointRef => InboxMessage,
+ onStopped: () => Unit): Unit = {
+ val shouldCallOnStop =
+ synchronized {
+ val data = endpoints.get(endpointName)
+ if (stopped || data == null) {
+ true
+ } else {
+ data.inbox.post(createMessageFn(data.ref))
+ receivers.put(data)
+ false
+ }
+ }
+ if (shouldCallOnStop) {
+ // We don't need to call `onStop` in the `synchronized` block
+ onStopped()
+ }
+ }
+
+ private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism",
+ Runtime.getRuntime.availableProcessors())
+
+ private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop")
+
+ (0 until parallelism) foreach { _ =>
+ executor.execute(new MessageLoop)
+ }
+
+ def stop(): Unit = {
+ synchronized {
+ if (stopped) {
+ return
+ }
+ stopped = true
+ }
+ // Stop all endpoints. This will queue all endpoints for processing by the message loops.
+ endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
+ // Enqueue a message that tells the message loops to stop.
+ receivers.put(PoisonEndpoint)
+ executor.shutdown()
+ }
+
+ def awaitTermination(): Unit = {
+ executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+ }
+
+ /**
+ * Return whether the endpoint exists
+ */
+ def verify(name: String): Boolean = {
+ endpoints.containsKey(name)
+ }
+
+ private class MessageLoop extends Runnable {
+ override def run(): Unit = {
+ try {
+ while (true) {
+ try {
+ val data = receivers.take()
+ if (data == PoisonEndpoint) {
+ // Put PoisonEndpoint back so that other MessageLoops can see it.
+ receivers.put(PoisonEndpoint)
+ return
+ }
+ data.inbox.process(Dispatcher.this)
+ } catch {
+ case NonFatal(e) => logError(e.getMessage, e)
+ }
+ }
+ } catch {
+ case ie: InterruptedException => // exit
+ }
+ }
+ }
+
+ /**
+ * A poison endpoint that indicates MessageLoop should exit its loop.
+ */
+ private val PoisonEndpoint = new EndpointData(null, null, null)
+}
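The dispatcher's threading model is a fixed pool of `MessageLoop`s draining one shared blocking queue and shutting down via a sentinel (`PoisonEndpoint`) that each loop re-enqueues for its siblings. A self-contained sketch of that pattern, purely illustrative and independent of the Spark classes:

    import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

    object PoisonPillDemo {
      private val queue = new LinkedBlockingQueue[String]()
      private val Poison = new String("POISON")   // sentinel, compared by reference

      def main(args: Array[String]): Unit = {
        val pool = Executors.newFixedThreadPool(2)
        (1 to 2).foreach { _ =>
          pool.execute(new Runnable {
            override def run(): Unit = {
              while (true) {
                val item = queue.take()
                if (item eq Poison) {
                  queue.put(Poison)               // let the other loop terminate too
                  return
                }
                println(s"processed $item")
              }
            }
          })
        }
        Seq("a", "b", "c").foreach(queue.put)
        queue.put(Poison)                         // ask all loops to stop
        pool.shutdown()
        pool.awaitTermination(10, TimeUnit.SECONDS)
      }
    }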
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
new file mode 100644
index 0000000000..6061c9b8de
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
+
+/**
+ * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists
+ */
+private[netty] case class ID(name: String)
+
+/**
+ * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
+ */
+private[netty] class IDVerifier(
+ override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint {
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case ID(name) => context.reply(dispatcher.verify(name))
+ }
+}
+
+private[netty] object IDVerifier {
+ val NAME = "id-verifier"
+}
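This is how `asyncSetupEndpointRefByURI` probes a remote environment before handing out a reference; roughly (sketch, where `idVerifierRef` is assumed to be a ref pointing at the remote "id-verifier" endpoint):

    import scala.concurrent.Future

    // Ask the remote RpcEnv whether an endpoint named "worker" is registered there.
    val exists: Future[Boolean] = idVerifierRef.ask[Boolean](ID("worker"))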
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
new file mode 100644
index 0000000000..4803548365
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.LinkedList
+import javax.annotation.concurrent.GuardedBy
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
+
+private[netty] sealed trait InboxMessage
+
+private[netty] case class ContentMessage(
+ senderAddress: RpcAddress,
+ content: Any,
+ needReply: Boolean,
+ context: NettyRpcCallContext) extends InboxMessage
+
+private[netty] case object OnStart extends InboxMessage
+
+private[netty] case object OnStop extends InboxMessage
+
+/**
+ * A broadcast message that indicates connecting to a remote node.
+ */
+private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage
+
+/**
+ * A broadcast message that indicates a remote connection is lost.
+ */
+private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage
+
+/**
+ * A broadcast message that indicates a network error
+ */
+private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress)
+ extends InboxMessage
+
+/**
+ * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
+ * @param endpointRef
+ * @param endpoint
+ */
+private[netty] class Inbox(
+ val endpointRef: NettyRpcEndpointRef,
+ val endpoint: RpcEndpoint) extends Logging {
+
+ @GuardedBy("this")
+ protected val messages = new LinkedList[InboxMessage]()
+
+ @GuardedBy("this")
+ private var stopped = false
+
+ @GuardedBy("this")
+ private var enableConcurrent = false
+
+ @GuardedBy("this")
+ private var workerCount = 0
+
+ // OnStart should be the first message to process
+ synchronized {
+ messages.add(OnStart)
+ }
+
+ /**
+ * Process stored messages.
+ */
+ def process(dispatcher: Dispatcher): Unit = {
+ var message: InboxMessage = null
+ synchronized {
+ if (!enableConcurrent && workerCount != 0) {
+ return
+ }
+ message = messages.poll()
+ if (message != null) {
+ workerCount += 1
+ } else {
+ return
+ }
+ }
+ while (true) {
+ safelyCall(endpoint) {
+ message match {
+ case ContentMessage(_sender, content, needReply, context) =>
+ val pf: PartialFunction[Any, Unit] =
+ if (needReply) {
+ endpoint.receiveAndReply(context)
+ } else {
+ endpoint.receive
+ }
+ try {
+ pf.applyOrElse[Any, Unit](content, { msg =>
+ throw new SparkException(s"Unmatched message $message from ${_sender}")
+ })
+ if (!needReply) {
+ context.finish()
+ }
+ } catch {
+ case NonFatal(e) =>
+ if (needReply) {
+ // If the sender asks a reply, we should send the error back to the sender
+ context.sendFailure(e)
+ } else {
+ context.finish()
+ throw e
+ }
+ }
+
+ case OnStart => {
+ endpoint.onStart()
+ if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+ synchronized {
+ enableConcurrent = true
+ }
+ }
+ }
+
+ case OnStop =>
+ assert(synchronized { workerCount } == 1)
+ dispatcher.removeRpcEndpointRef(endpoint)
+ endpoint.onStop()
+ assert(isEmpty, "OnStop should be the last message")
+
+ case Associated(remoteAddress) =>
+ endpoint.onConnected(remoteAddress)
+
+ case Disassociated(remoteAddress) =>
+ endpoint.onDisconnected(remoteAddress)
+
+ case AssociationError(cause, remoteAddress) =>
+ endpoint.onNetworkError(cause, remoteAddress)
+ }
+ }
+
+ synchronized {
+ // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
+ // every time.
+ if (!enableConcurrent && workerCount != 1) {
+ // If we are not the only one worker, exit
+ workerCount -= 1
+ return
+ }
+ message = messages.poll()
+ if (message == null) {
+ workerCount -= 1
+ return
+ }
+ }
+ }
+ }
+
+ def post(message: InboxMessage): Unit = {
+ val dropped =
+ synchronized {
+ if (stopped) {
+ // We already put "OnStop" into "messages", so we should drop further messages
+ true
+ } else {
+ messages.add(message)
+ false
+ }
+ }
+ if (dropped) {
+ onDrop(message)
+ }
+ }
+
+ def stop(): Unit = synchronized {
+ // The following code should be in `synchronized` so that we can make sure "OnStop" is the last
+ // message
+ if (!stopped) {
+ // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
+ // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
+ // safely.
+ enableConcurrent = false
+ stopped = true
+ messages.add(OnStop)
+ // Note: The concurrent events in messages will be processed one by one.
+ }
+ }
+
+ // Visible for testing.
+ protected def onDrop(message: InboxMessage): Unit = {
+ logWarning(s"Drop ${message} because $endpointRef is stopped")
+ }
+
+ def isEmpty: Boolean = synchronized { messages.isEmpty }
+
+ private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
+ try {
+ action
+ } catch {
+ case NonFatal(e) => {
+ try {
+ endpoint.onError(e)
+ } catch {
+ case NonFatal(e) => logWarning(s"Ignore error", e)
+ }
+ }
+ }
+ }
+
+}
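`OnStart` is always the first message an inbox delivers and `OnStop` the last, and for a `ThreadSafeRpcEndpoint` at most one thread drains it at a time. A usage sketch of the posting side, assuming `ref`, `endpoint` and `dispatcher` already exist:

    import org.apache.spark.rpc.RpcAddress

    val inbox = new Inbox(ref, endpoint)            // OnStart is enqueued automatically
    inbox.post(Associated(RpcAddress("remote-host", 7077)))
    inbox.process(dispatcher)                       // runs onStart(), then onConnected(...)
    inbox.stop()                                    // enqueues OnStop, disables concurrency
    inbox.process(dispatcher)                       // runs onStop() as the final message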
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
new file mode 100644
index 0000000000..1876b25592
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.net.URI
+
+import org.apache.spark.SparkException
+import org.apache.spark.rpc.RpcAddress
+
+private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) {
+
+ def toRpcAddress: RpcAddress = RpcAddress(host, port)
+
+ override val toString = s"spark://$name@$host:$port"
+}
+
+private[netty] object NettyRpcAddress {
+
+ def apply(sparkUrl: String): NettyRpcAddress = {
+ try {
+ val uri = new URI(sparkUrl)
+ val host = uri.getHost
+ val port = uri.getPort
+ val name = uri.getUserInfo
+ if (uri.getScheme != "spark" ||
+ host == null ||
+ port < 0 ||
+ name == null ||
+ (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
+ uri.getFragment != null ||
+ uri.getQuery != null) {
+ throw new SparkException("Invalid Spark URL: " + sparkUrl)
+ }
+ NettyRpcAddress(host, port, name)
+ } catch {
+ case e: java.net.URISyntaxException =>
+ throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
+ }
+ }
+
+}
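A quick round-trip of the URL format this class accepts (the values are made up):

    import org.apache.spark.rpc.netty.NettyRpcAddress

    val addr = NettyRpcAddress("spark://driver-endpoint@10.0.0.1:50001")
    // addr == NettyRpcAddress("10.0.0.1", 50001, "driver-endpoint")
    // addr.toString == "spark://driver-endpoint@10.0.0.1:50001"
    // addr.toRpcAddress == RpcAddress("10.0.0.1", 50001)
    // Anything that is not a bare spark://name@host:port URL throws SparkException.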
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
new file mode 100644
index 0000000000..75dcc02a0c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import scala.concurrent.Promise
+
+import org.apache.spark.Logging
+import org.apache.spark.network.client.RpcResponseCallback
+import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
+
+private[netty] abstract class NettyRpcCallContext(
+ endpointRef: NettyRpcEndpointRef,
+ override val senderAddress: RpcAddress,
+ needReply: Boolean) extends RpcCallContext with Logging {
+
+ protected def send(message: Any): Unit
+
+ override def reply(response: Any): Unit = {
+ if (needReply) {
+ send(AskResponse(endpointRef, response))
+ } else {
+ throw new IllegalStateException(
+ s"Cannot send $response to the sender because the sender won't handle it")
+ }
+ }
+
+ override def sendFailure(e: Throwable): Unit = {
+ if (needReply) {
+ send(AskResponse(endpointRef, RpcFailure(e)))
+ } else {
+ logError(e.getMessage, e)
+ throw new IllegalStateException(
+ "Cannot send reply to the sender because the sender won't handle it")
+ }
+ }
+
+ def finish(): Unit = {
+ if (!needReply) {
+ send(Ack(endpointRef))
+ }
+ }
+}
+
+/**
+ * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`.
+ */
+private[netty] class LocalNettyRpcCallContext(
+ endpointRef: NettyRpcEndpointRef,
+ senderAddress: RpcAddress,
+ needReply: Boolean,
+ p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+
+ override protected def send(message: Any): Unit = {
+ p.success(message)
+ }
+}
+
+/**
+ * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back.
+ */
+private[netty] class RemoteNettyRpcCallContext(
+ nettyEnv: NettyRpcEnv,
+ endpointRef: NettyRpcEndpointRef,
+ callback: RpcResponseCallback,
+ senderAddress: RpcAddress,
+ needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+
+ override protected def send(message: Any): Unit = {
+ val reply = nettyEnv.serialize(message)
+ callback.onSuccess(reply)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
new file mode 100644
index 0000000000..5522b40782
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -0,0 +1,504 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import java.io._
+import java.net.{InetSocketAddress, URI}
+import java.nio.ByteBuffer
+import java.util.Arrays
+import java.util.concurrent._
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.concurrent.{Future, Promise}
+import scala.reflect.ClassTag
+import scala.util.{DynamicVariable, Failure, Success}
+import scala.util.control.NonFatal
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.client._
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
+import org.apache.spark.network.server._
+import org.apache.spark.rpc._
+import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+private[netty] class NettyRpcEnv(
+ val conf: SparkConf,
+ javaSerializerInstance: JavaSerializerInstance,
+ host: String,
+ securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
+
+ private val transportConf =
+ SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0))
+
+ private val dispatcher: Dispatcher = new Dispatcher(this)
+
+ private val transportContext =
+ new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
+
+ private val clientFactory = {
+ val bootstraps: Seq[TransportClientBootstrap] =
+ if (securityManager.isAuthenticationEnabled()) {
+ Seq(new SaslClientBootstrap(transportConf, "", securityManager,
+ securityManager.isSaslEncryptionEnabled()))
+ } else {
+ Nil
+ }
+ transportContext.createClientFactory(bootstraps.asJava)
+ }
+
+ val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
+
+ // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
+ // to implement non-blocking send/ask.
+ // TODO: a non-blocking TransportClientFactory.createClient in future
+ private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
+ "netty-rpc-connection",
+ conf.getInt("spark.rpc.connect.threads", 256))
+
+ @volatile private var server: TransportServer = _
+
+ def start(port: Int): Unit = {
+ val bootstraps: Seq[TransportServerBootstrap] =
+ if (securityManager.isAuthenticationEnabled()) {
+ Seq(new SaslServerBootstrap(transportConf, securityManager))
+ } else {
+ Nil
+ }
+ server = transportContext.createServer(port, bootstraps.asJava)
+ dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
+ }
+
+ override lazy val address: RpcAddress = {
+ require(server != null, "NettyRpcEnv has not yet started")
+ RpcAddress(host, server.getPort())
+ }
+
+ override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
+ dispatcher.registerRpcEndpoint(name, endpoint)
+ }
+
+ def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
+ val addr = NettyRpcAddress(uri)
+ val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
+ val idVerifierRef =
+ new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this)
+ idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find =>
+ if (find) {
+ Future.successful(endpointRef)
+ } else {
+ Future.failed(new RpcEndpointNotFoundException(uri))
+ }
+ }(ThreadUtils.sameThread)
+ }
+
+ override def stop(endpointRef: RpcEndpointRef): Unit = {
+ require(endpointRef.isInstanceOf[NettyRpcEndpointRef])
+ dispatcher.stop(endpointRef)
+ }
+
+ private[netty] def send(message: RequestMessage): Unit = {
+ val remoteAddr = message.receiver.address
+ if (remoteAddr == address) {
+ val promise = Promise[Any]()
+ dispatcher.postMessage(message, promise)
+ promise.future.onComplete {
+ case Success(response) =>
+ val ack = response.asInstanceOf[Ack]
+ logDebug(s"Receive ack from ${ack.sender}")
+ case Failure(e) =>
+ logError(s"Exception when sending $message", e)
+ }(ThreadUtils.sameThread)
+ } else {
+ try {
+ // `createClient` will block if it cannot find a known connection, so we should run it in
+ // clientConnectionExecutor
+ clientConnectionExecutor.execute(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
+ client.sendRpc(serialize(message), new RpcResponseCallback {
+
+ override def onFailure(e: Throwable): Unit = {
+ logError(s"Exception when sending $message", e)
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ val ack = deserialize[Ack](response)
+ logDebug(s"Receive ack from ${ack.sender}")
+ }
+ })
+ }
+ })
+ } catch {
+ case e: RejectedExecutionException => {
+ // `send` after shutting clientConnectionExecutor down, ignore it
+ logWarning(s"Cannot send ${message} because RpcEnv is stopped")
+ }
+ }
+ }
+ }
+
+ private[netty] def ask(message: RequestMessage): Future[Any] = {
+ val promise = Promise[Any]()
+ val remoteAddr = message.receiver.address
+ if (remoteAddr == address) {
+ val p = Promise[Any]()
+ dispatcher.postMessage(message, p)
+ p.future.onComplete {
+ case Success(response) =>
+ val reply = response.asInstanceOf[AskResponse]
+ if (reply.reply.isInstanceOf[RpcFailure]) {
+ if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
+ logWarning(s"Ignore failure: ${reply.reply}")
+ }
+ } else if (!promise.trySuccess(reply.reply)) {
+ logWarning(s"Ignore message: ${reply}")
+ }
+ case Failure(e) =>
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore Exception", e)
+ }
+ }(ThreadUtils.sameThread)
+ } else {
+ try {
+ // `createClient` will block if it cannot find a known connection, so we should run it in
+ // clientConnectionExecutor
+ clientConnectionExecutor.execute(new Runnable {
+ override def run(): Unit = {
+ val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
+ client.sendRpc(serialize(message), new RpcResponseCallback {
+
+ override def onFailure(e: Throwable): Unit = {
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore Exception", e)
+ }
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ val reply = deserialize[AskResponse](response)
+ if (reply.reply.isInstanceOf[RpcFailure]) {
+ if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
+ logWarning(s"Ignore failure: ${reply.reply}")
+ }
+ } else if (!promise.trySuccess(reply.reply)) {
+ logWarning(s"Ignore message: ${reply}")
+ }
+ }
+ })
+ }
+ })
+ } catch {
+ case e: RejectedExecutionException => {
+ if (!promise.tryFailure(e)) {
+ logWarning(s"Ignore failure", e)
+ }
+ }
+ }
+ }
+ promise.future
+ }
+
+ private[netty] def serialize(content: Any): Array[Byte] = {
+ val buffer = javaSerializerInstance.serialize(content)
+ Arrays.copyOfRange(
+ buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
+ }
+
+ private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
+ deserialize { () =>
+ javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+ }
+ }
+
+ override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
+ dispatcher.getRpcEndpointRef(endpoint)
+ }
+
+ override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
+ new NettyRpcAddress(address.host, address.port, endpointName).toString
+
+ override def shutdown(): Unit = {
+ cleanup()
+ }
+
+ override def awaitTermination(): Unit = {
+ dispatcher.awaitTermination()
+ }
+
+ private def cleanup(): Unit = {
+ if (timeoutScheduler != null) {
+ timeoutScheduler.shutdownNow()
+ }
+ if (server != null) {
+ server.close()
+ }
+ if (clientFactory != null) {
+ clientFactory.close()
+ }
+ if (dispatcher != null) {
+ dispatcher.stop()
+ }
+ if (clientConnectionExecutor != null) {
+ clientConnectionExecutor.shutdownNow()
+ }
+ }
+
+ override def deserialize[T](deserializationAction: () => T): T = {
+ NettyRpcEnv.currentEnv.withValue(this) {
+ deserializationAction()
+ }
+ }
+}
+
+private[netty] object NettyRpcEnv extends Logging {
+
+ /**
+ * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
+ * Use `currentEnv` to wrap the deserialization code. E.g.,
+ *
+ * {{{
+ * NettyRpcEnv.currentEnv.withValue(this) {
+ * your deserialization code
+ * }
+ * }}}
+ */
+ private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)
+}
+
+private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
+
+ def create(config: RpcEnvConfig): RpcEnv = {
+ val sparkConf = config.conf
+ // Using JavaSerializerInstance in multiple threads is safe. However, if we plan to support
+ // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
+ val javaSerializerInstance =
+ new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
+ val nettyEnv =
+ new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
+ val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
+ nettyEnv.start(actualPort)
+ (nettyEnv, actualPort)
+ }
+ try {
+ Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
+ } catch {
+ case NonFatal(e) =>
+ nettyEnv.shutdown()
+ throw e
+ }
+ }
+}
+
+private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
+ extends RpcEndpointRef(conf) with Serializable with Logging {
+
+ @transient @volatile private var nettyEnv: NettyRpcEnv = _
+
+ @transient @volatile private var _address: NettyRpcAddress = _
+
+ def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
+ this(conf)
+ this._address = _address
+ this.nettyEnv = nettyEnv
+ }
+
+ override def address: RpcAddress = _address.toRpcAddress
+
+ private def readObject(in: ObjectInputStream): Unit = {
+ in.defaultReadObject()
+ _address = in.readObject().asInstanceOf[NettyRpcAddress]
+ nettyEnv = NettyRpcEnv.currentEnv.value
+ }
+
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ out.defaultWriteObject()
+ out.writeObject(_address)
+ }
+
+ override def name: String = _address.name
+
+ override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
+ val promise = Promise[Any]()
+ val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable {
+ override def run(): Unit = {
+ promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration))
+ }
+ }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
+ val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true))
+ f.onComplete { v =>
+ timeoutCancelable.cancel(true)
+ if (!promise.tryComplete(v)) {
+ logWarning(s"Ignore message $v")
+ }
+ }(ThreadUtils.sameThread)
+ promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
+ }
+
+ override def send(message: Any): Unit = {
+ require(message != null, "Message is null")
+ nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false))
+ }
+
+ override def toString: String = s"NettyRpcEndpointRef(${_address})"
+
+ def toURI: URI = new URI(s"spark://${_address}")
+
+ final override def equals(that: Any): Boolean = that match {
+ case other: NettyRpcEndpointRef => _address == other._address
+ case _ => false
+ }
+
+ final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
+}
+
+/**
+ * The message that is sent from the sender to the receiver.
+ */
+private[netty] case class RequestMessage(
+ senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean)
+
+/**
+ * The base trait for all messages that are sent back from the receiver to the sender.
+ */
+private[netty] trait ResponseMessage
+
+/**
+ * The reply for `ask` from the receiver side.
+ */
+private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any)
+ extends ResponseMessage
+
+/**
+ * A message to send back to the receiver side. It's necessary because [[TransportClient]] only
+ * cleans up the resources when it receives a reply.
+ */
+private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage
+
+/**
+ * A response that indicates some failure happens in the receiver side.
+ */
+private[netty] case class RpcFailure(e: Throwable)
+
+/**
+ * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast
+ * network events and forward messages to [[Dispatcher]].
+ */
+private[netty] class NettyRpcHandler(
+ dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
+
+ private type ClientAddress = RpcAddress
+ private type RemoteEnvAddress = RpcAddress
+
+ // Store all client addresses and their NettyRpcEnv addresses.
+ @GuardedBy("this")
+ private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
+
+ // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection
+ // count because `TransportClientFactory.createClient` will create multiple connections
+ // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection
+ // to send the message. See `TransportClientFactory.createClient` for more details.
+ @GuardedBy("this")
+ private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]()
+
+ override def receive(
+ client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
+ val requestMessage = nettyEnv.deserialize[RequestMessage](message)
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ assert(addr != null)
+ val remoteEnvAddress = requestMessage.senderAddress
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ val broadcastMessage =
+ synchronized {
+ // If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
+ if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
+ // clientAddr connects at the first time
+ val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
+ // Increase the connection number of remoteEnvAddress
+ remoteConnectionCount.put(remoteEnvAddress, count + 1)
+ if (count == 0) {
+ // This is the first connection, so fire "Associated"
+ Some(Associated(remoteEnvAddress))
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ }
+ broadcastMessage.foreach(dispatcher.broadcastMessage)
+ dispatcher.postMessage(requestMessage, callback)
+ }
+
+ override def getStreamManager: StreamManager = new OneForOneStreamManager
+
+ override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ if (addr != null) {
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ val broadcastMessage =
+ synchronized {
+ remoteAddresses.get(clientAddr).map(AssociationError(cause, _))
+ }
+ if (broadcastMessage.isEmpty) {
+ logError(cause.getMessage, cause)
+ } else {
+ dispatcher.broadcastMessage(broadcastMessage.get)
+ }
+ } else {
+ // If the channel is closed before connecting, its remoteAddress will be null.
+ // See java.net.Socket.getRemoteSocketAddress
+ // Because we cannot get a RpcAddress, just log it
+ logError("Exception before connecting to the client", cause)
+ }
+ }
+
+ override def connectionTerminated(client: TransportClient): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ if (addr != null) {
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ val broadcastMessage =
+ synchronized {
+ // If the last connection to a remote RpcEnv is terminated, we should broadcast
+ // "Disassociated"
+ remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
+ remoteAddresses -= clientAddr
+ val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
+ assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent")
+ if (count - 1 == 0) {
+ // We lost all clients, so clean up and fire "Disassociated"
+ remoteConnectionCount.remove(remoteEnvAddress)
+ Some(Disassociated(remoteEnvAddress))
+ } else {
+ // Decrease the connection number of remoteEnvAddress
+ remoteConnectionCount.put(remoteEnvAddress, count - 1)
+ None
+ }
+ }
+ }
+ broadcastMessage.foreach(dispatcher.broadcastMessage)
+ } else {
+ // If the channel is closed before connecting, its remoteAddress will be null. In this case,
+ // we can ignore it since we don't fire "Associated".
+ // See java.net.Socket.getRemoteSocketAddress
+ }
+ }
+
+}
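Putting the pieces together, an end-to-end sketch of standing up a Netty-backed env and asking an endpoint on it. This only compiles inside Spark's own source tree (the RPC API is `private[spark]`) and all names are hypothetical:

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    val conf = new SparkConf().set("spark.rpc", "netty")
    val env: RpcEnv = RpcEnv.create("demo", "localhost", 52345, conf, new SecurityManager(conf))

    // A trivial endpoint that answers "ping" with "pong".
    val ref = env.setupEndpoint("ping-pong", new RpcEndpoint {
      override val rpcEnv: RpcEnv = env
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case "ping" => context.reply("pong")
      }
    })

    // A local ref short-circuits through Dispatcher.postMessage(...); a ref obtained
    // from another process serializes a RequestMessage and goes through Netty.
    val answer = Await.result(ref.ask[String]("ping"), 10.seconds)   // "pong"
    env.shutdown()
    env.awaitTermination()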
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
index 7478ab0fc2..e749631bf6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
import scala.concurrent.{ExecutionContext, Future}
-import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
+import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
import org.apache.spark.storage.BlockManagerMessages._
@@ -33,7 +33,7 @@ class BlockManagerSlaveEndpoint(
override val rpcEnv: RpcEnv,
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
- extends RpcEndpoint with Logging {
+ extends ThreadSafeRpcEndpoint with Logging {
private val asyncThreadPool =
ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
@@ -80,7 +80,7 @@ class BlockManagerSlaveEndpoint(
future.onSuccess { case response =>
logDebug("Done " + actionMessage + ", response is " + response)
context.reply(response)
- logDebug("Sent response: " + response + " to " + context.sender)
+ logDebug("Sent response: " + response + " to " + context.senderAddress)
}
future.onFailure { case t: Throwable =>
logError("Error in " + actionMessage, t)
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 22e291a2b4..1ed098379e 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -85,7 +85,11 @@ private[spark] object ThreadUtils {
*/
def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = {
val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
- Executors.newSingleThreadScheduledExecutor(threadFactory)
+ val executor = new ScheduledThreadPoolExecutor(1, threadFactory)
+ // By default, a cancelled task is not automatically removed from the work queue until its delay
+ // elapses. We have to enable it manually.
+ executor.setRemoveOnCancelPolicy(true)
+ executor
}
/**
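The cancel policy matters because `NettyRpcEndpointRef.ask` schedules a timeout task and cancels it as soon as the reply arrives; without `setRemoveOnCancelPolicy(true)` every cancelled task would sit in the scheduler queue until its delay expired. A standalone, JDK-only illustration:

    import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}

    object RemoveOnCancelDemo {
      def main(args: Array[String]): Unit = {
        val executor = new ScheduledThreadPoolExecutor(1)
        executor.setRemoveOnCancelPolicy(true)

        // Schedule a far-future "timeout", then cancel it right away, the way an
        // ask does once the reply beats the timeout.
        val timeoutTask = executor.schedule(new Runnable {
          override def run(): Unit = println("timed out")
        }, 1, TimeUnit.HOURS)
        timeoutTask.cancel(true)

        // With the policy on, the queue is already empty; otherwise the cancelled
        // task would linger here for an hour.
        println(s"queued tasks: ${executor.getQueue.size}")   // prints 0
        executor.shutdown()
      }
    }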