author     zsxwing <zsxwing@gmail.com>         2015-10-03 01:04:35 -0700
committer  Reynold Xin <rxin@databricks.com>   2015-10-03 01:04:35 -0700
commit     107320c9bbfe2496963a4e75e60fd6ba7fbfbabc (patch)
tree       869ad3f1a13445c563493492075feb77d4c45606
parent     314bc68435ac3901a97724b9eccd1daf8f89578e (diff)
[SPARK-6028] [CORE] Remerge #6457: new RPC implementation and also pick #8905

This PR reverts https://github.com/apache/spark/commit/02144d6745ec0a6d8877d969feb82139bd22437f in order to remerge #6457, and also includes the commits from #8905.

Author: zsxwing <zsxwing@gmail.com>

Closes #8944 from zsxwing/SPARK-6028.
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala                               |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                                        |  20
-rwxr-xr-x  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala                            |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala                     |  13
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala                              |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala                                 |  51
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala                |  22
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala                                      |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala                             |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala                            | 218
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala                            |  39
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala                                 | 227
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala                       |  56
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala                   |  87
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala                           | 504
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala               |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ThreadUtils.scala                                |   6
-rw-r--r--  core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala                           |  10
-rw-r--r--  core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala                                |   2
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala                |   6
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala                                 |  81
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala                             | 123
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala                            | 150
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala                  |  29
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala                      |  38
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala                  |  67
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java          |   4
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java               |   2
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java  |   1
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala        |   2
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala                   |   5
31 files changed, 1715 insertions, 71 deletions
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index e4cb72e39d..45e12e40c8 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -45,7 +45,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
- val hostPort = context.sender.address.hostPort
+ val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
val serializedSize = mapOutputStatuses.size
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index c6fef7f91f..cfde27fb2e 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -20,11 +20,10 @@ package org.apache.spark
import java.io.File
import java.net.Socket
-import akka.actor.ActorSystem
-
import scala.collection.mutable
import scala.util.Properties
+import akka.actor.ActorSystem
import com.google.common.collect.MapMaker
import org.apache.spark.annotation.DeveloperApi
@@ -41,7 +40,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
-import org.apache.spark.util.{RpcUtils, Utils}
+import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
/**
* :: DeveloperApi ::
@@ -57,6 +56,7 @@ import org.apache.spark.util.{RpcUtils, Utils}
class SparkEnv (
val executorId: String,
private[spark] val rpcEnv: RpcEnv,
+ _actorSystem: ActorSystem, // TODO Remove actorSystem
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -76,7 +76,7 @@ class SparkEnv (
// TODO Remove actorSystem
@deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0")
- val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+ val actorSystem: ActorSystem = _actorSystem
private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
@@ -100,6 +100,9 @@ class SparkEnv (
blockManager.master.stop()
metricsSystem.stop()
outputCommitCoordinator.stop()
+ if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) {
+ actorSystem.shutdown()
+ }
rpcEnv.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
@@ -249,7 +252,13 @@ object SparkEnv extends Logging {
// Create the ActorSystem for Akka and get the port it binds to.
val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
- val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+ val actorSystem: ActorSystem =
+ if (rpcEnv.isInstanceOf[AkkaRpcEnv]) {
+ rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+ } else {
+ // Create a ActorSystem for legacy codes
+ AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1
+ }
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
if (isDriver) {
@@ -395,6 +404,7 @@ object SparkEnv extends Logging {
val envInstance = new SparkEnv(
executorId,
rpcEnv,
+ actorSystem,
serializer,
closureSerializer,
cacheManager,
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 770927c80f..93a1b3f310 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -329,7 +329,7 @@ private[deploy] class Worker(
registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate(
new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
- self.send(ReregisterWithMaster)
+ Option(self).foreach(_.send(ReregisterWithMaster))
}
},
INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index 735c4f0927..ab56fde938 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -24,14 +24,13 @@ import org.apache.spark.rpc._
* Actor which connects to a worker process and terminates the JVM if the connection is severed.
* Provides fate sharing between a worker and its associated child processes.
*/
-private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String)
+private[spark] class WorkerWatcher(
+ override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false)
extends RpcEndpoint with Logging {
- override def onStart() {
- logInfo(s"Connecting to worker $workerUrl")
- if (!isTesting) {
- rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
- }
+ logInfo(s"Connecting to worker $workerUrl")
+ if (!isTesting) {
+ rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
}
// Used to avoid shutting down JVM during tests
@@ -40,8 +39,6 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin
// true rather than calling `System.exit`. The user can check `isShutDown` to know if
// `exitNonZero` is called.
private[deploy] var isShutDown = false
- private[deploy] def setTesting(testing: Boolean) = isTesting = testing
- private var isTesting = false
// Lets filter events only from the worker's rpc system
private val expectedAddress = RpcAddress.fromURIString(workerUrl)
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
index 3e5b64265e..f527ec86ab 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
@@ -37,5 +37,5 @@ private[spark] trait RpcCallContext {
/**
* The sender of this message.
*/
- def sender: RpcEndpointRef
+ def senderAddress: RpcAddress
}
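
The RpcCallContext change above replaces the sender's endpoint reference with a plain RpcAddress. As a rough sketch (assuming code inside the org.apache.spark package; EchoEndpoint is a made-up name, not part of this patch), an endpoint now reads the caller's address like this:

    import org.apache.spark.Logging
    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    // Hypothetical endpoint: replies to a String and logs where the request came from.
    class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint with Logging {
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String =>
          // senderAddress is now an RpcAddress (host and port), no longer an RpcEndpointRef
          logInfo(s"Received '$msg' from ${context.senderAddress.hostPort}")
          context.reply(msg)
      }
    }
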
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index dfcbc51cdf..f1ddc6d2cd 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -29,20 +29,6 @@ private[spark] trait RpcEnvFactory {
}
/**
- * A trait that requires RpcEnv thread-safely sending messages to it.
- *
- * Thread-safety means processing of one message happens before processing of the next message by
- * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a
- * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
- * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
- *
- * However, there is no guarantee that the same thread will be executing the same
- * [[ThreadSafeRpcEndpoint]] for different messages.
- */
-private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
-
-
-/**
* An end point for the RPC that defines what functions to trigger given a message.
*
* It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
@@ -101,38 +87,39 @@ private[spark] trait RpcEndpoint {
}
/**
- * Invoked before [[RpcEndpoint]] starts to handle any message.
+ * Invoked when `remoteAddress` is connected to the current node.
*/
- def onStart(): Unit = {
+ def onConnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
/**
- * Invoked when [[RpcEndpoint]] is stopping.
+ * Invoked when `remoteAddress` is lost.
*/
- def onStop(): Unit = {
+ def onDisconnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
/**
- * Invoked when `remoteAddress` is connected to the current node.
+ * Invoked when some network error happens in the connection between the current node and
+ * `remoteAddress`.
*/
- def onConnected(remoteAddress: RpcAddress): Unit = {
+ def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
/**
- * Invoked when `remoteAddress` is lost.
+ * Invoked before [[RpcEndpoint]] starts to handle any message.
*/
- def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ def onStart(): Unit = {
// By default, do nothing.
}
/**
- * Invoked when some network error happens in the connection between the current node and
- * `remoteAddress`.
+ * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot
+ * use it to send or ask messages.
*/
- def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ def onStop(): Unit = {
// By default, do nothing.
}
@@ -146,3 +133,17 @@ private[spark] trait RpcEndpoint {
}
}
}
+
+/**
+ * A trait that requires RpcEnv thread-safely sending messages to it.
+ *
+ * Thread-safety means processing of one message happens before processing of the next message by
+ * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a
+ * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
+ * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
+ *
+ * However, there is no guarantee that the same thread will be executing the same
+ * [[ThreadSafeRpcEndpoint]] for different messages.
+ */
+private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint {
+}
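
To illustrate the lifecycle ordering and the thread-safety guarantee documented above, here is a minimal sketch of a ThreadSafeRpcEndpoint (ConnectionLogger is a hypothetical name, not part of this patch); plain fields are safe because messages are processed one at a time:

    import org.apache.spark.Logging
    import org.apache.spark.rpc.{RpcAddress, RpcEnv, ThreadSafeRpcEndpoint}

    class ConnectionLogger(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {
      // No volatile or locking needed: one message is fully processed before the next one starts.
      private var connected = Set.empty[RpcAddress]

      override def onStart(): Unit = logInfo("Endpoint started; ready to receive messages")

      override def receive: PartialFunction[Any, Unit] = {
        case msg => logInfo(s"Got $msg while ${connected.size} peers are connected")
      }

      override def onConnected(remoteAddress: RpcAddress): Unit = connected += remoteAddress

      override def onDisconnected(remoteAddress: RpcAddress): Unit = connected -= remoteAddress

      override def onStop(): Unit = logInfo("Endpoint stopping; `self` is null at this point")
    }
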
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala
new file mode 100644
index 0000000000..d177881fb3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+import org.apache.spark.SparkException
+
+private[rpc] class RpcEndpointNotFoundException(uri: String)
+ extends SparkException(s"Cannot find endpoint: $uri")
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 29debe8081..35e402c725 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -36,8 +36,9 @@ private[spark] object RpcEnv {
private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
// Add more RpcEnv implementations here
- val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
- val rpcEnvName = conf.get("spark.rpc", "akka")
+ val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
+ "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")
+ val rpcEnvName = conf.get("spark.rpc", "netty")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
}
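
With this hunk, "netty" becomes the default value of spark.rpc, while "akka" (or a fully qualified RpcEnvFactory class name, per the getOrElse fallback above) can still be selected explicitly. A minimal configuration sketch:

    import org.apache.spark.SparkConf

    // Pick the RPC implementation explicitly; omitting the key now selects "netty" by default.
    val conf = new SparkConf()
      .set("spark.rpc", "akka") // or "netty", or a fully qualified RpcEnvFactory class name
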
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index ad67e1c5ad..95132a4e4a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -166,9 +166,9 @@ private[spark] class AkkaRpcEnv private[akka] (
_sender ! AkkaMessage(response, false)
}
- // Some RpcEndpoints need to know the sender's address
- override val sender: RpcEndpointRef =
- new AkkaRpcEndpointRef(defaultAddress, _sender, conf)
+ // Use "lazy" because most of RpcEndpoints don't need "senderAddress"
+ override lazy val senderAddress: RpcAddress =
+ new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address
})
} else {
endpoint.receive
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
new file mode 100644
index 0000000000..d71e6f01db
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.concurrent.Promise
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SparkException, Logging}
+import org.apache.spark.network.client.RpcResponseCallback
+import org.apache.spark.rpc._
+import org.apache.spark.util.ThreadUtils
+
+private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
+
+ private class EndpointData(
+ val name: String,
+ val endpoint: RpcEndpoint,
+ val ref: NettyRpcEndpointRef) {
+ val inbox = new Inbox(ref, endpoint)
+ }
+
+ private val endpoints = new ConcurrentHashMap[String, EndpointData]()
+ private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]()
+
+ // Track the receivers whose inboxes may contain messages.
+ private val receivers = new LinkedBlockingQueue[EndpointData]()
+
+ @GuardedBy("this")
+ private var stopped = false
+
+ def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
+ val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name)
+ val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
+ synchronized {
+ if (stopped) {
+ throw new IllegalStateException("RpcEnv has been stopped")
+ }
+ if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
+ throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
+ }
+ val data = endpoints.get(name)
+ endpointRefs.put(data.endpoint, data.ref)
+ receivers.put(data)
+ }
+ endpointRef
+ }
+
+ def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)
+
+ def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)
+
+ // Should be idempotent
+ private def unregisterRpcEndpoint(name: String): Unit = {
+ val data = endpoints.remove(name)
+ if (data != null) {
+ data.inbox.stop()
+ receivers.put(data)
+ }
+ // Don't clean `endpointRefs` here because it's possible that some messages are being processed
+ // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
+ // `removeRpcEndpointRef`.
+ }
+
+ def stop(rpcEndpointRef: RpcEndpointRef): Unit = {
+ synchronized {
+ if (stopped) {
+ // This endpoint will be stopped by Dispatcher.stop() method.
+ return
+ }
+ unregisterRpcEndpoint(rpcEndpointRef.name)
+ }
+ }
+
+ /**
+ * Send a message to all registered [[RpcEndpoint]]s.
+ * @param message
+ */
+ def broadcastMessage(message: InboxMessage): Unit = {
+ val iter = endpoints.keySet().iterator()
+ while (iter.hasNext) {
+ val name = iter.next
+ postMessageToInbox(name, (_) => message,
+ () => { logWarning(s"Drop ${message} because ${name} has been stopped") })
+ }
+ }
+
+ def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
+ def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
+ val rpcCallContext =
+ new RemoteNettyRpcCallContext(
+ nettyEnv, sender, callback, message.senderAddress, message.needReply)
+ ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
+ }
+
+ def onEndpointStopped(): Unit = {
+ callback.onFailure(
+ new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
+ }
+
+ postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
+ }
+
+ def postMessage(message: RequestMessage, p: Promise[Any]): Unit = {
+ def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
+ val rpcCallContext =
+ new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p)
+ ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
+ }
+
+ def onEndpointStopped(): Unit = {
+ p.tryFailure(
+ new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
+ }
+
+ postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
+ }
+
+ private def postMessageToInbox(
+ endpointName: String,
+ createMessageFn: NettyRpcEndpointRef => InboxMessage,
+ onStopped: () => Unit): Unit = {
+ val shouldCallOnStop =
+ synchronized {
+ val data = endpoints.get(endpointName)
+ if (stopped || data == null) {
+ true
+ } else {
+ data.inbox.post(createMessageFn(data.ref))
+ receivers.put(data)
+ false
+ }
+ }
+ if (shouldCallOnStop) {
+ // We don't need to call `onStop` in the `synchronized` block
+ onStopped()
+ }
+ }
+
+ private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism",
+ Runtime.getRuntime.availableProcessors())
+
+ private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop")
+
+ (0 until parallelism) foreach { _ =>
+ executor.execute(new MessageLoop)
+ }
+
+ def stop(): Unit = {
+ synchronized {
+ if (stopped) {
+ return
+ }
+ stopped = true
+ }
+ // Stop all endpoints. This will queue all endpoints for processing by the message loops.
+ endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
+ // Enqueue a message that tells the message loops to stop.
+ receivers.put(PoisonEndpoint)
+ executor.shutdown()
+ }
+
+ def awaitTermination(): Unit = {
+ executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+ }
+
+ /**
+ * Return if the endpoint exists
+ */
+ def verify(name: String): Boolean = {
+ endpoints.containsKey(name)
+ }
+
+ private class MessageLoop extends Runnable {
+ override def run(): Unit = {
+ try {
+ while (true) {
+ try {
+ val data = receivers.take()
+ if (data == PoisonEndpoint) {
+ // Put PoisonEndpoint back so that other MessageLoops can see it.
+ receivers.put(PoisonEndpoint)
+ return
+ }
+ data.inbox.process(Dispatcher.this)
+ } catch {
+ case NonFatal(e) => logError(e.getMessage, e)
+ }
+ }
+ } catch {
+ case ie: InterruptedException => // exit
+ }
+ }
+ }
+
+ /**
+ * A poison endpoint that indicates MessageLoop should exit its loop.
+ */
+ private val PoisonEndpoint = new EndpointData(null, null, null)
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
new file mode 100644
index 0000000000..6061c9b8de
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
+
+/**
+ * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists
+ */
+private[netty] case class ID(name: String)
+
+/**
+ * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
+ */
+private[netty] class IDVerifier(
+ override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint {
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case ID(name) => context.reply(dispatcher.verify(name))
+ }
+}
+
+private[netty] object IDVerifier {
+ val NAME = "id-verifier"
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
new file mode 100644
index 0000000000..b669f59a28
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.LinkedList
+import javax.annotation.concurrent.GuardedBy
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
+
+private[netty] sealed trait InboxMessage
+
+private[netty] case class ContentMessage(
+ senderAddress: RpcAddress,
+ content: Any,
+ needReply: Boolean,
+ context: NettyRpcCallContext) extends InboxMessage
+
+private[netty] case object OnStart extends InboxMessage
+
+private[netty] case object OnStop extends InboxMessage
+
+/**
+ * A broadcast message that indicates connecting to a remote node.
+ */
+private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage
+
+/**
+ * A broadcast message that indicates a remote connection is lost.
+ */
+private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage
+
+/**
+ * A broadcast message that indicates a network error
+ */
+private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress)
+ extends InboxMessage
+
+/**
+ * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
+ * @param endpointRef
+ * @param endpoint
+ */
+private[netty] class Inbox(
+ val endpointRef: NettyRpcEndpointRef,
+ val endpoint: RpcEndpoint) extends Logging {
+
+ inbox =>
+
+ @GuardedBy("this")
+ protected val messages = new LinkedList[InboxMessage]()
+
+ @GuardedBy("this")
+ private var stopped = false
+
+ @GuardedBy("this")
+ private var enableConcurrent = false
+
+ @GuardedBy("this")
+ private var workerCount = 0
+
+ // OnStart should be the first message to process
+ inbox.synchronized {
+ messages.add(OnStart)
+ }
+
+ /**
+ * Process stored messages.
+ */
+ def process(dispatcher: Dispatcher): Unit = {
+ var message: InboxMessage = null
+ inbox.synchronized {
+ if (!enableConcurrent && workerCount != 0) {
+ return
+ }
+ message = messages.poll()
+ if (message != null) {
+ workerCount += 1
+ } else {
+ return
+ }
+ }
+ while (true) {
+ safelyCall(endpoint) {
+ message match {
+ case ContentMessage(_sender, content, needReply, context) =>
+ val pf: PartialFunction[Any, Unit] =
+ if (needReply) {
+ endpoint.receiveAndReply(context)
+ } else {
+ endpoint.receive
+ }
+ try {
+ pf.applyOrElse[Any, Unit](content, { msg =>
+ throw new SparkException(s"Unmatched message $message from ${_sender}")
+ })
+ if (!needReply) {
+ context.finish()
+ }
+ } catch {
+ case NonFatal(e) =>
+ if (needReply) {
+ // If the sender asks a reply, we should send the error back to the sender
+ context.sendFailure(e)
+ } else {
+ context.finish()
+ throw e
+ }
+ }
+
+ case OnStart => {
+ endpoint.onStart()
+ if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
+ inbox.synchronized {
+ if (!stopped) {
+ enableConcurrent = true
+ }
+ }
+ }
+ }
+
+ case OnStop =>
+ val _workCount = inbox.synchronized {
+ workerCount
+ }
+ assert(_workCount == 1, s"There should be only one worker but was ${_workCount}")
+ dispatcher.removeRpcEndpointRef(endpoint)
+ endpoint.onStop()
+ assert(isEmpty, "OnStop should be the last message")
+
+ case Associated(remoteAddress) =>
+ endpoint.onConnected(remoteAddress)
+
+ case Disassociated(remoteAddress) =>
+ endpoint.onDisconnected(remoteAddress)
+
+ case AssociationError(cause, remoteAddress) =>
+ endpoint.onNetworkError(cause, remoteAddress)
+ }
+ }
+
+ inbox.synchronized {
+ // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
+ // every time.
+ if (!enableConcurrent && workerCount != 1) {
+ // If we are not the only one worker, exit
+ workerCount -= 1
+ return
+ }
+ message = messages.poll()
+ if (message == null) {
+ workerCount -= 1
+ return
+ }
+ }
+ }
+ }
+
+ def post(message: InboxMessage): Unit = {
+ val dropped =
+ inbox.synchronized {
+ if (stopped) {
+ // We already put "OnStop" into "messages", so we should drop further messages
+ true
+ } else {
+ messages.add(message)
+ false
+ }
+ }
+ if (dropped) {
+ onDrop(message)
+ }
+ }
+
+ def stop(): Unit = inbox.synchronized {
+ // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last
+ // message
+ if (!stopped) {
+ // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
+ // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
+ // safely.
+ enableConcurrent = false
+ stopped = true
+ messages.add(OnStop)
+ // Note: The concurrent events in messages will be processed one by one.
+ }
+ }
+
+ // Visible for testing.
+ protected def onDrop(message: InboxMessage): Unit = {
+ logWarning(s"Drop ${message} because $endpointRef is stopped")
+ }
+
+ def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }
+
+ private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
+ try {
+ action
+ } catch {
+ case NonFatal(e) => {
+ try {
+ endpoint.onError(e)
+ } catch {
+ case NonFatal(e) => logWarning(s"Ignore error", e)
+ }
+ }
+ }
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
new file mode 100644
index 0000000000..1876b25592
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.net.URI
+
+import org.apache.spark.SparkException
+import org.apache.spark.rpc.RpcAddress
+
+private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) {
+
+ def toRpcAddress: RpcAddress = RpcAddress(host, port)
+
+ override val toString = s"spark://$name@$host:$port"
+}
+
+private[netty] object NettyRpcAddress {
+
+ def apply(sparkUrl: String): NettyRpcAddress = {
+ try {
+ val uri = new URI(sparkUrl)
+ val host = uri.getHost
+ val port = uri.getPort
+ val name = uri.getUserInfo
+ if (uri.getScheme != "spark" ||
+ host == null ||
+ port < 0 ||
+ name == null ||
+ (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
+ uri.getFragment != null ||
+ uri.getQuery != null) {
+ throw new SparkException("Invalid Spark URL: " + sparkUrl)
+ }
+ NettyRpcAddress(host, port, name)
+ } catch {
+ case e: java.net.URISyntaxException =>
+ throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
+ }
+ }
+
+}
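
The endpoint URI shape accepted above is spark://<name>@<host>:<port>. Since NettyRpcAddress is private[netty], the following sketch only compiles inside that package and is shown purely to illustrate the parsing behavior:

    import org.apache.spark.rpc.netty.NettyRpcAddress

    val addr = NettyRpcAddress("spark://MapOutputTracker@driver-host:7077")
    // addr.name == "MapOutputTracker"
    // addr.toRpcAddress == RpcAddress("driver-host", 7077)
    // Any other shape (wrong scheme, missing name, extra path/query/fragment) throws SparkException.
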
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
new file mode 100644
index 0000000000..75dcc02a0c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import scala.concurrent.Promise
+
+import org.apache.spark.Logging
+import org.apache.spark.network.client.RpcResponseCallback
+import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
+
+private[netty] abstract class NettyRpcCallContext(
+ endpointRef: NettyRpcEndpointRef,
+ override val senderAddress: RpcAddress,
+ needReply: Boolean) extends RpcCallContext with Logging {
+
+ protected def send(message: Any): Unit
+
+ override def reply(response: Any): Unit = {
+ if (needReply) {
+ send(AskResponse(endpointRef, response))
+ } else {
+ throw new IllegalStateException(
+ s"Cannot send $response to the sender because the sender won't handle it")
+ }
+ }
+
+ override def sendFailure(e: Throwable): Unit = {
+ if (needReply) {
+ send(AskResponse(endpointRef, RpcFailure(e)))
+ } else {
+ logError(e.getMessage, e)
+ throw new IllegalStateException(
+ "Cannot send reply to the sender because the sender won't handle it")
+ }
+ }
+
+ def finish(): Unit = {
+ if (!needReply) {
+ send(Ack(endpointRef))
+ }
+ }
+}
+
+/**
+ * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`.
+ */
+private[netty] class LocalNettyRpcCallContext(
+ endpointRef: NettyRpcEndpointRef,
+ senderAddress: RpcAddress,
+ needReply: Boolean,
+ p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+
+ override protected def send(message: Any): Unit = {
+ p.success(message)
+ }
+}
+
+/**
+ * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back.
+ */
+private[netty] class RemoteNettyRpcCallContext(
+ nettyEnv: NettyRpcEnv,
+ endpointRef: NettyRpcEndpointRef,
+ callback: RpcResponseCallback,
+ senderAddress: RpcAddress,
+ needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+
+ override protected def send(message: Any): Unit = {
+ val reply = nettyEnv.serialize(message)
+ callback.onSuccess(reply)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
new file mode 100644
index 0000000000..5522b40782
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -0,0 +1,504 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import java.io._
+import java.net.{InetSocketAddress, URI}
+import java.nio.ByteBuffer
+import java.util.Arrays
+import java.util.concurrent._
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.concurrent.{Future, Promise}
+import scala.reflect.ClassTag
+import scala.util.{DynamicVariable, Failure, Success}
+import scala.util.control.NonFatal
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.client._
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
+import org.apache.spark.network.server._
+import org.apache.spark.rpc._
+import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+private[netty] class NettyRpcEnv(
+ val conf: SparkConf,
+ javaSerializerInstance: JavaSerializerInstance,
+ host: String,
+ securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
+
+ private val transportConf =
+ SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0))
+
+ private val dispatcher: Dispatcher = new Dispatcher(this)
+
+ private val transportContext =
+ new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
+
+ private val clientFactory = {
+ val bootstraps: Seq[TransportClientBootstrap] =
+ if (securityManager.isAuthenticationEnabled()) {
+ Seq(new SaslClientBootstrap(transportConf, "", securityManager,
+ securityManager.isSaslEncryptionEnabled()))
+ } else {
+ Nil
+ }
+ transportContext.createClientFactory(bootstraps.asJava)
+ }
+
+ val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
+
+ // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
+ // to implement non-blocking send/ask.
+ // TODO: a non-blocking TransportClientFactory.createClient in future
+ private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
+ "netty-rpc-connection",
+ conf.getInt("spark.rpc.connect.threads", 256))
+
+ @volatile private var server: TransportServer = _
+
+ def start(port: Int): Unit = {
+ val bootstraps: Seq[TransportServerBootstrap] =
+ if (securityManager.isAuthenticationEnabled()) {
+ Seq(new SaslServerBootstrap(transportConf, securityManager))
+ } else {
+ Nil
+ }
+ server = transportContext.createServer(port, bootstraps.asJava)
+ dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
+ }
+
+ override lazy val address: RpcAddress = {
+ require(server != null, "NettyRpcEnv has not yet started")
+ RpcAddress(host, server.getPort())
+ }
+
+ override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
+ dispatcher.registerRpcEndpoint(name, endpoint)
+ }
+
+ def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
+ val addr = NettyRpcAddress(uri)
+ val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
+ val idVerifierRef =
+ new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this)
+ idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find =>
+ if (find) {
+ Future.successful(endpointRef)
+ } else {
+ Future.failed(new RpcEndpointNotFoundException(uri))
+ }
+ }(ThreadUtils.sameThread)
+ }
+
+ override def stop(endpointRef: RpcEndpointRef): Unit = {
+ require(endpointRef.isInstanceOf[NettyRpcEndpointRef])
+ dispatcher.stop(endpointRef)
+ }
+
+ private[netty] def send(message: RequestMessage): Unit = {
+ val remoteAddr = message.receiver.address
+ if (remoteAddr == address) {
+ val promise = Promise[Any]()
+ dispatcher.postMessage(message, promise)
+ promise.future.onComplete {
+ case Success(response) =>
+ val ack = response.asInstanceOf[Ack]
+ logDebug(s"Receive ack from ${ack.sender}")
+ case Failure(e) =>
+ logError(s"Exception when sending $message", e)
+ }(ThreadUtils.sameThread)
+ } else {
+ try {
+ // `createClient` will block if it cannot find a known connection, so we should run it in
+ // clientConnectionExecutor
+ clientConnectionExecutor.execute(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
+ client.sendRpc(serialize(message), new RpcResponseCallback {
+
+ override def onFailure(e: Throwable): Unit = {
+ logError(s"Exception when sending $message", e)
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ val ack = deserialize[Ack](response)
+ logDebug(s"Receive ack from ${ack.sender}")
+ }
+ })
+ }
+ })
+ } catch {
+ case e: RejectedExecutionException => {
+ // `send` after shutting clientConnectionExecutor down, ignore it
+ logWarning(s"Cannot send ${message} because RpcEnv is stopped")
+ }
+ }
+ }
+ }
+
+ private[netty] def ask(message: RequestMessage): Future[Any] = {
+ val promise = Promise[Any]()
+ val remoteAddr = message.receiver.address
+ if (remoteAddr == address) {
+ val p = Promise[Any]()
+ dispatcher.postMessage(message, p)
+ p.future.onComplete {
+ case Success(response) =>
+ val reply = response.asInstanceOf[AskResponse]
+ if (reply.reply.isInstanceOf[RpcFailure]) {
+ if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
+ logWarning(s"Ignore failure: ${reply.reply}")
+ }
+ } else if (!promise.trySuccess(reply.reply)) {
+ logWarning(s"Ignore message: ${reply}")
+ }
+ case Failure(e) =>
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore Exception", e)
+ }
+ }(ThreadUtils.sameThread)
+ } else {
+ try {
+ // `createClient` will block if it cannot find a known connection, so we should run it in
+ // clientConnectionExecutor
+ clientConnectionExecutor.execute(new Runnable {
+ override def run(): Unit = {
+ val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
+ client.sendRpc(serialize(message), new RpcResponseCallback {
+
+ override def onFailure(e: Throwable): Unit = {
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore Exception", e)
+ }
+ }
+
+ override def onSuccess(response: Array[Byte]): Unit = {
+ val reply = deserialize[AskResponse](response)
+ if (reply.reply.isInstanceOf[RpcFailure]) {
+ if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
+ logWarning(s"Ignore failure: ${reply.reply}")
+ }
+ } else if (!promise.trySuccess(reply.reply)) {
+ logWarning(s"Ignore message: ${reply}")
+ }
+ }
+ })
+ }
+ })
+ } catch {
+ case e: RejectedExecutionException => {
+ if (!promise.tryFailure(e)) {
+ logWarning(s"Ignore failure", e)
+ }
+ }
+ }
+ }
+ promise.future
+ }
+
+ private[netty] def serialize(content: Any): Array[Byte] = {
+ val buffer = javaSerializerInstance.serialize(content)
+ Arrays.copyOfRange(
+ buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
+ }
+
+ private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
+ deserialize { () =>
+ javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+ }
+ }
+
+ override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
+ dispatcher.getRpcEndpointRef(endpoint)
+ }
+
+ override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
+ new NettyRpcAddress(address.host, address.port, endpointName).toString
+
+ override def shutdown(): Unit = {
+ cleanup()
+ }
+
+ override def awaitTermination(): Unit = {
+ dispatcher.awaitTermination()
+ }
+
+ private def cleanup(): Unit = {
+ if (timeoutScheduler != null) {
+ timeoutScheduler.shutdownNow()
+ }
+ if (server != null) {
+ server.close()
+ }
+ if (clientFactory != null) {
+ clientFactory.close()
+ }
+ if (dispatcher != null) {
+ dispatcher.stop()
+ }
+ if (clientConnectionExecutor != null) {
+ clientConnectionExecutor.shutdownNow()
+ }
+ }
+
+ override def deserialize[T](deserializationAction: () => T): T = {
+ NettyRpcEnv.currentEnv.withValue(this) {
+ deserializationAction()
+ }
+ }
+}
+
+private[netty] object NettyRpcEnv extends Logging {
+
+ /**
+ * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
+ * Use `currentEnv` to wrap the deserialization codes. E.g.,
+ *
+ * {{{
+ * NettyRpcEnv.currentEnv.withValue(this) {
+ * your deserialization codes
+ * }
+ * }}}
+ */
+ private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)
+}
+
+private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
+
+ def create(config: RpcEnvConfig): RpcEnv = {
+ val sparkConf = config.conf
+ // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
+ // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
+ val javaSerializerInstance =
+ new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
+ val nettyEnv =
+ new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
+ val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
+ nettyEnv.start(actualPort)
+ (nettyEnv, actualPort)
+ }
+ try {
+ Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
+ } catch {
+ case NonFatal(e) =>
+ nettyEnv.shutdown()
+ throw e
+ }
+ }
+}
+
+private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
+ extends RpcEndpointRef(conf) with Serializable with Logging {
+
+ @transient @volatile private var nettyEnv: NettyRpcEnv = _
+
+ @transient @volatile private var _address: NettyRpcAddress = _
+
+ def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
+ this(conf)
+ this._address = _address
+ this.nettyEnv = nettyEnv
+ }
+
+ override def address: RpcAddress = _address.toRpcAddress
+
+ private def readObject(in: ObjectInputStream): Unit = {
+ in.defaultReadObject()
+ _address = in.readObject().asInstanceOf[NettyRpcAddress]
+ nettyEnv = NettyRpcEnv.currentEnv.value
+ }
+
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ out.defaultWriteObject()
+ out.writeObject(_address)
+ }
+
+ override def name: String = _address.name
+
+ override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
+ val promise = Promise[Any]()
+ val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable {
+ override def run(): Unit = {
+ promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration))
+ }
+ }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
+ val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true))
+ f.onComplete { v =>
+ timeoutCancelable.cancel(true)
+ if (!promise.tryComplete(v)) {
+ logWarning(s"Ignore message $v")
+ }
+ }(ThreadUtils.sameThread)
+ promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
+ }
+
+ override def send(message: Any): Unit = {
+ require(message != null, "Message is null")
+ nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false))
+ }
+
+ override def toString: String = s"NettyRpcEndpointRef(${_address})"
+
+ def toURI: URI = new URI(s"spark://${_address}")
+
+ final override def equals(that: Any): Boolean = that match {
+ case other: NettyRpcEndpointRef => _address == other._address
+ case _ => false
+ }
+
+ final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
+}
+
+/**
+ * The message that is sent from the sender to the receiver.
+ */
+private[netty] case class RequestMessage(
+ senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean)
+
+/**
+ * The base trait for all messages that are sent back from the receiver to the sender.
+ */
+private[netty] trait ResponseMessage
+
+/**
+ * The reply for `ask` from the receiver side.
+ */
+private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any)
+ extends ResponseMessage
+
+/**
+ * A message to send back to the receiver side. It's necessary because [[TransportClient]] only
+ * clean the resources when it receives a reply.
+ */
+private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage
+
+/**
+ * A response that indicates some failure happens in the receiver side.
+ */
+private[netty] case class RpcFailure(e: Throwable)
+
+/**
+ * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast
+ * network events and forward messages to [[Dispatcher]].
+ */
+private[netty] class NettyRpcHandler(
+ dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
+
+ private type ClientAddress = RpcAddress
+ private type RemoteEnvAddress = RpcAddress
+
+ // Store all client addresses and their NettyRpcEnv addresses.
+ @GuardedBy("this")
+ private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
+
+ // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection
+ // count because `TransportClientFactory.createClient` will create multiple connections
+ // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection
+ // to send the message. See `TransportClientFactory.createClient` for more details.
+ @GuardedBy("this")
+ private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]()
+
+ override def receive(
+ client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
+ val requestMessage = nettyEnv.deserialize[RequestMessage](message)
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ assert(addr != null)
+ val remoteEnvAddress = requestMessage.senderAddress
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ val broadcastMessage =
+ synchronized {
+ // If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
+ if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
+ // clientAddr connects at the first time
+ val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
+ // Increase the connection number of remoteEnvAddress
+ remoteConnectionCount.put(remoteEnvAddress, count + 1)
+ if (count == 0) {
+ // This is the first connection, so fire "Associated"
+ Some(Associated(remoteEnvAddress))
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ }
+ broadcastMessage.foreach(dispatcher.broadcastMessage)
+ dispatcher.postMessage(requestMessage, callback)
+ }
+
+ override def getStreamManager: StreamManager = new OneForOneStreamManager
+
+ override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ if (addr != null) {
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ val broadcastMessage =
+ synchronized {
+ remoteAddresses.get(clientAddr).map(AssociationError(cause, _))
+ }
+ if (broadcastMessage.isEmpty) {
+ logError(cause.getMessage, cause)
+ } else {
+ dispatcher.broadcastMessage(broadcastMessage.get)
+ }
+ } else {
+ // If the channel is closed before connecting, its remoteAddress will be null.
+ // See java.net.Socket.getRemoteSocketAddress
+ // Because we cannot get a RpcAddress, just log it
+ logError("Exception before connecting to the client", cause)
+ }
+ }
+
+ override def connectionTerminated(client: TransportClient): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ if (addr != null) {
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ val broadcastMessage =
+ synchronized {
+ // If the last connection to a remote RpcEnv is terminated, we should broadcast
+ // "Disassociated"
+ remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
+ remoteAddresses -= clientAddr
+ val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
+ assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent")
+ if (count - 1 == 0) {
+ // We lost all clients, so clean up and fire "Disassociated"
+ remoteConnectionCount.remove(remoteEnvAddress)
+ Some(Disassociated(remoteEnvAddress))
+ } else {
+ // Decrease the connection number of remoteEnvAddress
+ remoteConnectionCount.put(remoteEnvAddress, count - 1)
+ None
+ }
+ }
+ }
+ broadcastMessage.foreach(dispatcher.broadcastMessage)
+ } else {
+ // If the channel is closed before connecting, its remoteAddress will be null. In this case,
+ // we can ignore it since we don't fire "Associated".
+ // See java.net.Socket.getRemoteSocketAddress
+ }
+ }
+
+}
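
Putting the pieces together, a rough end-to-end sketch of the Netty-backed RpcEnv (assuming code inside the org.apache.spark package, since the rpc API is private[spark]; the endpoint name and port are illustrative, not part of this patch):

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    val conf = new SparkConf().set("spark.rpc", "netty")
    val env = RpcEnv.create("example", "localhost", 12345, conf, new SecurityManager(conf))

    // Register an endpoint; the returned ref can be handed to other RpcEnvs or used locally.
    val ref = env.setupEndpoint("echo", new RpcEndpoint {
      override val rpcEnv: RpcEnv = env
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply(msg.reverse)
      }
    })

    // `ask` goes through NettyRpcEnv.ask above (the local-address fast path in this case).
    val reply = Await.result(ref.ask[String]("ping"), 10.seconds)
    env.shutdown()
    env.awaitTermination()
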
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
index 7478ab0fc2..e749631bf6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
import scala.concurrent.{ExecutionContext, Future}
-import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
+import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
import org.apache.spark.storage.BlockManagerMessages._
@@ -33,7 +33,7 @@ class BlockManagerSlaveEndpoint(
override val rpcEnv: RpcEnv,
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
- extends RpcEndpoint with Logging {
+ extends ThreadSafeRpcEndpoint with Logging {
private val asyncThreadPool =
ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
@@ -80,7 +80,7 @@ class BlockManagerSlaveEndpoint(
future.onSuccess { case response =>
logDebug("Done " + actionMessage + ", response is " + response)
context.reply(response)
- logDebug("Sent response: " + response + " to " + context.sender)
+ logDebug("Sent response: " + response + " to " + context.senderAddress)
}
future.onFailure { case t: Throwable =>
logError("Error in " + actionMessage, t)
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 22e291a2b4..1ed098379e 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -85,7 +85,11 @@ private[spark] object ThreadUtils {
*/
def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = {
val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
- Executors.newSingleThreadScheduledExecutor(threadFactory)
+ val executor = new ScheduledThreadPoolExecutor(1, threadFactory)
+ // By default, a cancelled task is not automatically removed from the work queue until its delay
+ // elapses. We have to enable it manually.
+ executor.setRemoveOnCancelPolicy(true)
+ executor
}
/**
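
The change above matters for schedulers like the ask-timeout executor in NettyRpcEnv: without setRemoveOnCancelPolicy(true), a cancelled timeout task stays in the executor's work queue until its full delay elapses. A small sketch of the behavior this enables (names are illustrative):

    import java.util.concurrent.TimeUnit

    import org.apache.spark.util.ThreadUtils

    val scheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("example-timeout")
    val timeoutTask = scheduler.schedule(new Runnable {
      override def run(): Unit = println("timed out")
    }, 120, TimeUnit.SECONDS)

    // The reply arrived early, so cancel the pending timeout. With removeOnCancelPolicy enabled
    // the task is dropped from the queue immediately instead of lingering for two more minutes.
    timeoutTask.cancel(true)
    scheduler.shutdown()
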
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index af4e68950f..7e70308bb3 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -168,10 +168,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerShuffle(10, 1)
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0)))
- val sender = mock(classOf[RpcEndpointRef])
- when(sender.address).thenReturn(RpcAddress("localhost", 12345))
+ val senderAddress = RpcAddress("localhost", 12345)
val rpcCallContext = mock(classOf[RpcCallContext])
- when(rpcCallContext.sender).thenReturn(sender)
+ when(rpcCallContext.senderAddress).thenReturn(senderAddress)
masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10))
verify(rpcCallContext).reply(any())
verify(rpcCallContext, never()).sendFailure(any())
@@ -198,10 +197,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
}
- val sender = mock(classOf[RpcEndpointRef])
- when(sender.address).thenReturn(RpcAddress("localhost", 12345))
+ val senderAddress = RpcAddress("localhost", 12345)
val rpcCallContext = mock(classOf[RpcCallContext])
- when(rpcCallContext.sender).thenReturn(sender)
+ when(rpcCallContext.senderAddress).thenReturn(senderAddress)
masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
verify(rpcCallContext, never()).reply(any())
verify(rpcCallContext).sendFailure(isA(classOf[SparkException]))
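
The test change above stubs RpcCallContext.senderAddress directly instead of mocking a whole RpcEndpointRef just to reach its address. A minimal Mockito sketch of that pattern; SenderAddressMockSketch is an illustrative wrapper, not part of the patch.

import org.mockito.Mockito.{mock, when}

import org.apache.spark.rpc.{RpcAddress, RpcCallContext}

// Sketch: stub only senderAddress on a mocked call context.
object SenderAddressMockSketch {
  def mockedContext(): RpcCallContext = {
    val context = mock(classOf[RpcCallContext])
    when(context.senderAddress).thenReturn(RpcAddress("localhost", 12345))
    context
  }
  // A test would pass mockedContext() to receiveAndReply(...) and then verify
  // context.reply(...) or context.sendFailure(...) as in the suite above.
}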
diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
index 33270bec62..2d14249855 100644
--- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
+++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
@@ -41,6 +41,7 @@ object SSLSampleConfigs {
def sparkSSLConfig(): SparkConf = {
val conf = new SparkConf(loadDefaults = false)
+ conf.set("spark.rpc", "akka")
conf.set("spark.ssl.enabled", "true")
conf.set("spark.ssl.keyStore", keyStorePath)
conf.set("spark.ssl.keyStorePassword", "password")
@@ -54,6 +55,7 @@ object SSLSampleConfigs {
def sparkSSLConfigUntrusted(): SparkConf = {
val conf = new SparkConf(loadDefaults = false)
+ conf.set("spark.rpc", "akka")
conf.set("spark.ssl.enabled", "true")
conf.set("spark.ssl.keyStore", untrustedKeyStorePath)
conf.set("spark.ssl.keyStorePassword", "password")
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
index e9034e39a7..40c24bdecc 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
@@ -26,8 +26,7 @@ class WorkerWatcherSuite extends SparkFunSuite {
val conf = new SparkConf()
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
- val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
- workerWatcher.setTesting(testing = true)
+ val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true)
rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234))
assert(workerWatcher.isShutDown)
@@ -39,8 +38,7 @@ class WorkerWatcherSuite extends SparkFunSuite {
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
val otherRpcAddress = RpcAddress("4.3.2.1", 1234)
- val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
- workerWatcher.setTesting(testing = true)
+ val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true)
rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
workerWatcher.onDisconnected(otherRpcAddress)
assert(!workerWatcher.isShutDown)
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 6ceafe4337..3bead6395d 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.rpc
+import java.io.NotSerializableException
import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException}
import scala.collection.mutable
@@ -99,7 +100,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint)
-
val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello")
val reply = newRpcEndpointRef.askWithRetry[String]("Echo")
assert("Echo" === reply)
@@ -328,9 +328,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override def onStop(): Unit = {
selfOption = Option(self)
}
-
- override def onError(cause: Throwable): Unit = {
- }
})
env.stop(endpointRef)
@@ -516,6 +513,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(events === List(
("onConnected", remoteAddress),
("onNetworkError", remoteAddress),
+ ("onDisconnected", remoteAddress)) ||
+ events === List(
+ ("onConnected", remoteAddress),
("onDisconnected", remoteAddress)))
}
}
@@ -535,15 +535,84 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
"local", env.address, "sendWithReply-unserializable-error")
try {
val f = rpcEndpointRef.ask[String]("hello")
- intercept[TimeoutException] {
+ val e = intercept[Exception] {
Await.result(f, 1 seconds)
}
+ assert(e.isInstanceOf[TimeoutException] || // For Akka
+ e.isInstanceOf[NotSerializableException] // For Netty
+ )
} finally {
anotherEnv.shutdown()
anotherEnv.awaitTermination()
}
}
+ test("port conflict") {
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port)
+ assert(anotherEnv.address.port != env.address.port)
+ }
+
+ test("send with authentication") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val localEnv = createRpcEnv(conf, "authentication-local", 13345)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
+
+ try {
+ @volatile var message: String = null
+ localEnv.setupEndpoint("send-authentication", new RpcEndpoint {
+ override val rpcEnv = localEnv
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case msg: String => message = msg
+ }
+ })
+ val rpcEndpointRef =
+ remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication")
+ rpcEndpointRef.send("hello")
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ assert("hello" === message)
+ }
+ } finally {
+ localEnv.shutdown()
+ localEnv.awaitTermination()
+ remoteEnv.shutdown()
+ remoteEnv.awaitTermination()
+ }
+ }
+
+ test("ask with authentication") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val localEnv = createRpcEnv(conf, "authentication-local", 13345)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
+
+ try {
+ localEnv.setupEndpoint("ask-authentication", new RpcEndpoint {
+ override val rpcEnv = localEnv
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case msg: String => {
+ context.reply(msg)
+ }
+ }
+ })
+ val rpcEndpointRef =
+ remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication")
+ val reply = rpcEndpointRef.askWithRetry[String]("hello")
+ assert("hello" === reply)
+ } finally {
+ localEnv.shutdown()
+ localEnv.awaitTermination()
+ remoteEnv.shutdown()
+ remoteEnv.awaitTermination()
+ }
+ }
+
test("construct RpcTimeout with conf property") {
val conf = new SparkConf
@@ -612,7 +681,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
// once the future is complete to verify addMessageIfTimeout was invoked
val reply3 =
intercept[RpcTimeoutException] {
- Await.result(fut3, 200 millis)
+ Await.result(fut3, 2000 millis)
}.getMessage
// When the future timed out, the recover callback should have used
diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala
new file mode 100644
index 0000000000..5e8da3e205
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalactic.TripleEquals
+
+class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals {
+
+ override val rpcEnv: RpcEnv = null
+
+ @volatile private var receiveMessages = ArrayBuffer[Any]()
+
+ @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]()
+
+ @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]()
+
+ @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]()
+
+ @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable, RpcAddress)]()
+
+ @volatile private var started = false
+
+ @volatile private var stopped = false
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case message: Any => receiveMessages += message
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case message: Any => receiveAndReplyMessages += message
+ }
+
+ override def onConnected(remoteAddress: RpcAddress): Unit = {
+ onConnectedMessages += remoteAddress
+ }
+
+ /**
+ * Invoked when a network error occurs in the connection between the current node and
+ * `remoteAddress`.
+ */
+ override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ onNetworkErrorMessages += cause -> remoteAddress
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ onDisconnectedMessages += remoteAddress
+ }
+
+ def numReceiveMessages: Int = receiveMessages.size
+
+ override def onStart(): Unit = {
+ started = true
+ }
+
+ override def onStop(): Unit = {
+ stopped = true
+ }
+
+ def verifyStarted(): Unit = {
+ assert(started, "RpcEndpoint is not started")
+ }
+
+ def verifyStopped(): Unit = {
+ assert(stopped, "RpcEndpoint is not stopped")
+ }
+
+ def verifyReceiveMessages(expected: Seq[Any]): Unit = {
+ assert(receiveMessages === expected)
+ }
+
+ def verifySingleReceiveMessage(message: Any): Unit = {
+ verifyReceiveMessages(List(message))
+ }
+
+ def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = {
+ assert(receiveAndReplyMessages === expected)
+ }
+
+ def verifySingleReceiveAndReplyMessage(message: Any): Unit = {
+ verifyReceiveAndReplyMessages(List(message))
+ }
+
+ def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = {
+ verifyOnConnectedMessages(List(remoteAddress))
+ }
+
+ def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = {
+ assert(onConnectedMessages === expected)
+ }
+
+ def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = {
+ verifyOnDisconnectedMessages(List(remoteAddress))
+ }
+
+ def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = {
+ assert(onDisconnectedMessages === expected)
+ }
+
+ def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ verifyOnNetworkErrorMessages(List(cause -> remoteAddress))
+ }
+
+ def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]): Unit = {
+ assert(onNetworkErrorMessages === expected)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
new file mode 100644
index 0000000000..120cf1b6fa
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.mockito.Mockito._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint}
+
+class InboxSuite extends SparkFunSuite {
+
+ test("post") {
+ val endpoint = new TestRpcEndpoint
+ val endpointRef = mock(classOf[NettyRpcEndpointRef])
+ when(endpointRef.name).thenReturn("hello")
+
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val inbox = new Inbox(endpointRef, endpoint)
+ val message = ContentMessage(null, "hi", false, null)
+ inbox.post(message)
+ inbox.process(dispatcher)
+ assert(inbox.isEmpty)
+
+ endpoint.verifySingleReceiveMessage("hi")
+
+ inbox.stop()
+ inbox.process(dispatcher)
+ assert(inbox.isEmpty)
+ endpoint.verifyStarted()
+ endpoint.verifyStopped()
+ }
+
+ test("post: with reply") {
+ val endpoint = new TestRpcEndpoint
+ val endpointRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val inbox = new Inbox(endpointRef, endpoint)
+ val message = ContentMessage(null, "hi", true, null)
+ inbox.post(message)
+ inbox.process(dispatcher)
+ assert(inbox.isEmpty)
+
+ endpoint.verifySingleReceiveAndReplyMessage("hi")
+ }
+
+ test("post: multiple threads") {
+ val endpoint = new TestRpcEndpoint
+ val endpointRef = mock(classOf[NettyRpcEndpointRef])
+ when(endpointRef.name).thenReturn("hello")
+
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val numDroppedMessages = new AtomicInteger(0)
+ val inbox = new Inbox(endpointRef, endpoint) {
+ override def onDrop(message: InboxMessage): Unit = {
+ numDroppedMessages.incrementAndGet()
+ }
+ }
+
+ val exitLatch = new CountDownLatch(10)
+
+ for (_ <- 0 until 10) {
+ new Thread {
+ override def run(): Unit = {
+ for (_ <- 0 until 100) {
+ val message = ContentMessage(null, "hi", false, null)
+ inbox.post(message)
+ }
+ exitLatch.countDown()
+ }
+ }.start()
+ }
+ // Try to process some messages
+ inbox.process(dispatcher)
+ inbox.stop()
+ // After `stop` is called, further messages will be dropped. However, while `stop` is being
+ // called, some messages may still be posted to the Inbox, so process them here.
+ inbox.process(dispatcher)
+ assert(inbox.isEmpty)
+
+ exitLatch.await(30, TimeUnit.SECONDS)
+
+ assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get)
+ endpoint.verifyStarted()
+ endpoint.verifyStopped()
+ }
+
+ test("post: Associated") {
+ val endpoint = new TestRpcEndpoint
+ val endpointRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val remoteAddress = RpcAddress("localhost", 11111)
+
+ val inbox = new Inbox(endpointRef, endpoint)
+ inbox.post(Associated(remoteAddress))
+ inbox.process(dispatcher)
+
+ endpoint.verifySingleOnConnectedMessage(remoteAddress)
+ }
+
+ test("post: Disassociated") {
+ val endpoint = new TestRpcEndpoint
+ val endpointRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val remoteAddress = RpcAddress("localhost", 11111)
+
+ val inbox = new Inbox(endpointRef, endpoint)
+ inbox.post(Disassociated(remoteAddress))
+ inbox.process(dispatcher)
+
+ endpoint.verifySingleOnDisconnectedMessage(remoteAddress)
+ }
+
+ test("post: AssociationError") {
+ val endpoint = new TestRpcEndpoint
+ val endpointRef = mock(classOf[NettyRpcEndpointRef])
+ val dispatcher = mock(classOf[Dispatcher])
+
+ val remoteAddress = RpcAddress("localhost", 11111)
+ val cause = new RuntimeException("Oops")
+
+ val inbox = new Inbox(endpointRef, endpoint)
+ inbox.post(AssociationError(cause, remoteAddress))
+ inbox.process(dispatcher)
+
+ endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
new file mode 100644
index 0000000000..a5d43d3704
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import org.apache.spark.SparkFunSuite
+
+class NettyRpcAddressSuite extends SparkFunSuite {
+
+ test("toString") {
+ val addr = NettyRpcAddress("localhost", 12345, "test")
+ assert(addr.toString === "spark://test@localhost:12345")
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
new file mode 100644
index 0000000000..be19668e17
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.rpc._
+
+class NettyRpcEnvSuite extends RpcEnvSuite {
+
+ override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
+ val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf))
+ new NettyRpcEnvFactory().create(config)
+ }
+
+ test("non-existent endpoint") {
+ val uri = env.uriOf("test", env.address, "nonexist-endpoint")
+ val e = intercept[RpcEndpointNotFoundException] {
+ env.setupEndpointRef("test", env.address, "nonexist-endpoint")
+ }
+ assert(e.getMessage.contains(uri))
+ }
+
+}
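
createRpcEnv above shows the only moving parts needed to stand up the Netty-based RpcEnv: an RpcEnvConfig plus NettyRpcEnvFactory. The sketch below wires those into a tiny end-to-end example; the package declaration mirrors the suite's (visibility of the factory outside that package is an assumption), and the "greeter" endpoint name and port are illustrative.

package org.apache.spark.rpc.netty

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc._

// Sketch: create a Netty RpcEnv from an RpcEnvConfig and register an
// illustrative endpoint, then shut everything down.
object NettyRpcEnvSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
    val config = RpcEnvConfig(conf, "test", "localhost", 52345, new SecurityManager(conf))
    val env: RpcEnv = new NettyRpcEnvFactory().create(config)
    env.setupEndpoint("greeter", new ThreadSafeRpcEndpoint {
      override val rpcEnv: RpcEnv = env
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case name: String => context.reply("hello, " + name)
      }
    })
    env.shutdown()
    env.awaitTermination()
  }
}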
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
new file mode 100644
index 0000000000..06ca035d19
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.net.InetSocketAddress
+
+import io.netty.channel.Channel
+import org.mockito.Mockito._
+import org.mockito.Matchers._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
+import org.apache.spark.rpc._
+
+class NettyRpcHandlerSuite extends SparkFunSuite {
+
+ val env = mock(classOf[NettyRpcEnv])
+ when(env.deserialize(any(classOf[Array[Byte]]))(any())).
+ thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+
+ test("receive") {
+ val dispatcher = mock(classOf[Dispatcher])
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+
+ val channel = mock(classOf[Channel])
+ val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
+ when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
+ nettyRpcHandler.receive(client, null, null)
+
+ when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001))
+ nettyRpcHandler.receive(client, null, null)
+
+ verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
+ }
+
+ test("connectionTerminated") {
+ val dispatcher = mock(classOf[Dispatcher])
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+
+ val channel = mock(classOf[Channel])
+ val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
+ when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
+ nettyRpcHandler.receive(client, null, null)
+
+ when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
+ nettyRpcHandler.connectionTerminated(client)
+
+ verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
+ verify(dispatcher, times(1)).broadcastMessage(Disassociated(RpcAddress("localhost", 12345)))
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index df841288a0..fbb8bb6b2f 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -78,6 +78,10 @@ public class TransportClient implements Closeable {
this.handler = Preconditions.checkNotNull(handler);
}
+ public Channel getChannel() {
+ return channel;
+ }
+
public boolean isActive() {
return channel.isOpen() || channel.isActive();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 2ba92a40f8..dbb7f95f55 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -52,4 +52,6 @@ public abstract class RpcHandler {
* No further requests will come from this client.
*/
public void connectionTerminated(TransportClient client) { }
+
+ public void exceptionCaught(Throwable cause, TransportClient client) { }
}
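
The RpcHandler change adds an exceptionCaught hook with a no-op default, and the following TransportRequestHandler hunk forwards channel-level failures to it. A hedged sketch of a subclass using the hook; LoggingRpcHandler, its echo behavior, and the logging are illustrative only, and how the real NettyRpcHandler reacts to such failures is not shown in this excerpt.

import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}

// Sketch: a handler that is notified of channel failures via exceptionCaught
// and logs which remote peer they came from (using the new getChannel accessor).
class LoggingRpcHandler extends RpcHandler {
  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      message: Array[Byte],
      callback: RpcResponseCallback): Unit = {
    callback.onSuccess(message) // echo, for illustration only
  }

  override def getStreamManager(): StreamManager = streamManager

  override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
    System.err.println("RPC failure from " + client.getChannel().remoteAddress() + ": " + cause)
  }
}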
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index df60278058..96941d26be 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -71,6 +71,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
@Override
public void exceptionCaught(Throwable cause) {
+ rpcHandler.exceptionCaught(cause, reverseClient);
}
@Override
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 204e6142fd..d053e9e849 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -474,7 +474,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// Remote messages
case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) =>
val successful =
- registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address)
+ registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.senderAddress)
context.reply(successful)
case AddBlock(receivedBlockInfo) =>
context.reply(addBlock(receivedBlockInfo))
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 07a0a45e60..0df31736c1 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -556,10 +556,7 @@ private[spark] class ApplicationMaster(
override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean)
extends RpcEndpoint with Logging {
- override def onStart(): Unit = {
- driver.send(RegisterClusterManager(self))
-
- }
+ driver.send(RegisterClusterManager(self))
override def receive: PartialFunction[Any, Unit] = {
case x: AddWebUIFilter =>