aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-09-24 08:25:44 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-24 08:25:44 -0700
commit02144d6745ec0a6d8877d969feb82139bd22437f (patch)
tree78c8b59523a2c8d51a9bfcfc1b337d80c647a2f8
parentd91967e159f416924bbd7f0db25156588d4bd7b1 (diff)
downloadspark-02144d6745ec0a6d8877d969feb82139bd22437f.tar.gz
spark-02144d6745ec0a6d8877d969feb82139bd22437f.tar.bz2
spark-02144d6745ec0a6d8877d969feb82139bd22437f.zip
Revert "[SPARK-6028][Core]A new RPC implemetation based on the network module"
This reverts commit 084e4e126211d74a79e8dbd2d0e604dd3c650822.
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala20
-rwxr-xr-xcore/src/main/scala/org/apache/spark/deploy/worker/Worker.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala51
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala22
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala218
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala39
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala220
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala56
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala87
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala504
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/util/ThreadUtils.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala10
-rw-r--r--core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala78
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala123
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala148
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala29
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala38
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala67
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportClient.java4
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java2
-rw-r--r--network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java1
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala2
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala5
31 files changed, 68 insertions, 1708 deletions
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index e380c5b8af..94eb8daa85 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -45,7 +45,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
- val hostPort = context.senderAddress.hostPort
+ val hostPort = context.sender.address.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
val serializedSize = mapOutputStatuses.size
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index cfde27fb2e..c6fef7f91f 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -20,10 +20,11 @@ package org.apache.spark
import java.io.File
import java.net.Socket
+import akka.actor.ActorSystem
+
import scala.collection.mutable
import scala.util.Properties
-import akka.actor.ActorSystem
import com.google.common.collect.MapMaker
import org.apache.spark.annotation.DeveloperApi
@@ -40,7 +41,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
-import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
+import org.apache.spark.util.{RpcUtils, Utils}
/**
* :: DeveloperApi ::
@@ -56,7 +57,6 @@ import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
class SparkEnv (
val executorId: String,
private[spark] val rpcEnv: RpcEnv,
- _actorSystem: ActorSystem, // TODO Remove actorSystem
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -76,7 +76,7 @@ class SparkEnv (
// TODO Remove actorSystem
@deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0")
- val actorSystem: ActorSystem = _actorSystem
+ val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
@@ -100,9 +100,6 @@ class SparkEnv (
blockManager.master.stop()
metricsSystem.stop()
outputCommitCoordinator.stop()
- if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) {
- actorSystem.shutdown()
- }
rpcEnv.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
@@ -252,13 +249,7 @@ object SparkEnv extends Logging {
// Create the ActorSystem for Akka and get the port it binds to.
val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
- val actorSystem: ActorSystem =
- if (rpcEnv.isInstanceOf[AkkaRpcEnv]) {
- rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
- } else {
- // Create a ActorSystem for legacy codes
- AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1
- }
+ val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
if (isDriver) {
@@ -404,7 +395,6 @@ object SparkEnv extends Logging {
val envInstance = new SparkEnv(
executorId,
rpcEnv,
- actorSystem,
serializer,
closureSerializer,
cacheManager,
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 93a1b3f310..770927c80f 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -329,7 +329,7 @@ private[deploy] class Worker(
registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate(
new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
- Option(self).foreach(_.send(ReregisterWithMaster))
+ self.send(ReregisterWithMaster)
}
},
INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index ab56fde938..735c4f0927 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -24,13 +24,14 @@ import org.apache.spark.rpc._
* Actor which connects to a worker process and terminates the JVM if the connection is severed.
* Provides fate sharing between a worker and its associated child processes.
*/
-private[spark] class WorkerWatcher(
- override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false)
+private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String)
extends RpcEndpoint with Logging {
- logInfo(s"Connecting to worker $workerUrl")
- if (!isTesting) {
- rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
+ override def onStart() {
+ logInfo(s"Connecting to worker $workerUrl")
+ if (!isTesting) {
+ rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
+ }
}
// Used to avoid shutting down JVM during tests
@@ -39,6 +40,8 @@ private[spark] class WorkerWatcher(
// true rather than calling `System.exit`. The user can check `isShutDown` to know if
// `exitNonZero` is called.
private[deploy] var isShutDown = false
+ private[deploy] def setTesting(testing: Boolean) = isTesting = testing
+ private var isTesting = false
// Lets filter events only from the worker's rpc system
private val expectedAddress = RpcAddress.fromURIString(workerUrl)
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
index f527ec86ab..3e5b64265e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
@@ -37,5 +37,5 @@ private[spark] trait RpcCallContext {
/**
* The sender of this message.
*/
- def senderAddress: RpcAddress
+ def sender: RpcEndpointRef
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index f1ddc6d2cd..dfcbc51cdf 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -29,6 +29,20 @@ private[spark] trait RpcEnvFactory {
}
/**
+ * A trait that requires RpcEnv thread-safely sending messages to it.
+ *
+ * Thread-safety means processing of one message happens before processing of the next message by
+ * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a
+ * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
+ * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
+ *
+ * However, there is no guarantee that the same thread will be executing the same
+ * [[ThreadSafeRpcEndpoint]] for different messages.
+ */
+private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+
+/**
* An end point for the RPC that defines what functions to trigger given a message.
*
* It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
@@ -87,39 +101,38 @@ private[spark] trait RpcEndpoint {
}
/**
- * Invoked when `remoteAddress` is connected to the current node.
+ * Invoked before [[RpcEndpoint]] starts to handle any message.
*/
- def onConnected(remoteAddress: RpcAddress): Unit = {
+ def onStart(): Unit = {
// By default, do nothing.
}
/**
- * Invoked when `remoteAddress` is lost.
+ * Invoked when [[RpcEndpoint]] is stopping.
*/
- def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ def onStop(): Unit = {
// By default, do nothing.
}
/**
- * Invoked when some network error happens in the connection between the current node and
- * `remoteAddress`.
+ * Invoked when `remoteAddress` is connected to the current node.
*/
- def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ def onConnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
/**
- * Invoked before [[RpcEndpoint]] starts to handle any message.
+ * Invoked when `remoteAddress` is lost.
*/
- def onStart(): Unit = {
+ def onDisconnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
/**
- * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot
- * use it to send or ask messages.
+ * Invoked when some network error happens in the connection between the current node and
+ * `remoteAddress`.
*/
- def onStop(): Unit = {
+ def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}
@@ -133,17 +146,3 @@ private[spark] trait RpcEndpoint {
}
}
}
-
-/**
- * A trait that requires RpcEnv thread-safely sending messages to it.
- *
- * Thread-safety means processing of one message happens before processing of the next message by
- * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a
- * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
- * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
- *
- * However, there is no guarantee that the same thread will be executing the same
- * [[ThreadSafeRpcEndpoint]] for different messages.
- */
-private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint {
-}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala
deleted file mode 100644
index d177881fb3..0000000000
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.rpc
-
-import org.apache.spark.SparkException
-
-private[rpc] class RpcEndpointNotFoundException(uri: String)
- extends SparkException(s"Cannot find endpoint: $uri")
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index afe1ee8a0b..29debe8081 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -36,11 +36,8 @@ private[spark] object RpcEnv {
private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
// Add more RpcEnv implementations here
- val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
- "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")
- // Use "netty" by default so that Jenkins can run all tests using NettyRpcEnv.
- // Will change it back to "akka" before merging the new implementation.
- val rpcEnvName = conf.get("spark.rpc", "netty")
+ val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
+ val rpcEnvName = conf.get("spark.rpc", "akka")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 95132a4e4a..ad67e1c5ad 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -166,9 +166,9 @@ private[spark] class AkkaRpcEnv private[akka] (
_sender ! AkkaMessage(response, false)
}
- // Use "lazy" because most of RpcEndpoints don't need "senderAddress"
- override lazy val senderAddress: RpcAddress =
- new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address
+ // Some RpcEndpoints need to know the sender's address
+ override val sender: RpcEndpointRef =
+ new AkkaRpcEndpointRef(defaultAddress, _sender, conf)
})
} else {
endpoint.receive
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
deleted file mode 100644
index d71e6f01db..0000000000
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ /dev/null
@@ -1,218 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
-import javax.annotation.concurrent.GuardedBy
-
-import scala.collection.JavaConverters._
-import scala.concurrent.Promise
-import scala.util.control.NonFatal
-
-import org.apache.spark.{SparkException, Logging}
-import org.apache.spark.network.client.RpcResponseCallback
-import org.apache.spark.rpc._
-import org.apache.spark.util.ThreadUtils
-
-private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
-
- private class EndpointData(
- val name: String,
- val endpoint: RpcEndpoint,
- val ref: NettyRpcEndpointRef) {
- val inbox = new Inbox(ref, endpoint)
- }
-
- private val endpoints = new ConcurrentHashMap[String, EndpointData]()
- private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]()
-
- // Track the receivers whose inboxes may contain messages.
- private val receivers = new LinkedBlockingQueue[EndpointData]()
-
- @GuardedBy("this")
- private var stopped = false
-
- def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
- val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name)
- val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
- synchronized {
- if (stopped) {
- throw new IllegalStateException("RpcEnv has been stopped")
- }
- if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
- throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
- }
- val data = endpoints.get(name)
- endpointRefs.put(data.endpoint, data.ref)
- receivers.put(data)
- }
- endpointRef
- }
-
- def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)
-
- def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)
-
- // Should be idempotent
- private def unregisterRpcEndpoint(name: String): Unit = {
- val data = endpoints.remove(name)
- if (data != null) {
- data.inbox.stop()
- receivers.put(data)
- }
- // Don't clean `endpointRefs` here because it's possible that some messages are being processed
- // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
- // `removeRpcEndpointRef`.
- }
-
- def stop(rpcEndpointRef: RpcEndpointRef): Unit = {
- synchronized {
- if (stopped) {
- // This endpoint will be stopped by Dispatcher.stop() method.
- return
- }
- unregisterRpcEndpoint(rpcEndpointRef.name)
- }
- }
-
- /**
- * Send a message to all registered [[RpcEndpoint]]s.
- * @param message
- */
- def broadcastMessage(message: InboxMessage): Unit = {
- val iter = endpoints.keySet().iterator()
- while (iter.hasNext) {
- val name = iter.next
- postMessageToInbox(name, (_) => message,
- () => { logWarning(s"Drop ${message} because ${name} has been stopped") })
- }
- }
-
- def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
- def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
- val rpcCallContext =
- new RemoteNettyRpcCallContext(
- nettyEnv, sender, callback, message.senderAddress, message.needReply)
- ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
- }
-
- def onEndpointStopped(): Unit = {
- callback.onFailure(
- new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
- }
-
- postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
- }
-
- def postMessage(message: RequestMessage, p: Promise[Any]): Unit = {
- def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
- val rpcCallContext =
- new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p)
- ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
- }
-
- def onEndpointStopped(): Unit = {
- p.tryFailure(
- new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
- }
-
- postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped)
- }
-
- private def postMessageToInbox(
- endpointName: String,
- createMessageFn: NettyRpcEndpointRef => InboxMessage,
- onStopped: () => Unit): Unit = {
- val shouldCallOnStop =
- synchronized {
- val data = endpoints.get(endpointName)
- if (stopped || data == null) {
- true
- } else {
- data.inbox.post(createMessageFn(data.ref))
- receivers.put(data)
- false
- }
- }
- if (shouldCallOnStop) {
- // We don't need to call `onStop` in the `synchronized` block
- onStopped()
- }
- }
-
- private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism",
- Runtime.getRuntime.availableProcessors())
-
- private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop")
-
- (0 until parallelism) foreach { _ =>
- executor.execute(new MessageLoop)
- }
-
- def stop(): Unit = {
- synchronized {
- if (stopped) {
- return
- }
- stopped = true
- }
- // Stop all endpoints. This will queue all endpoints for processing by the message loops.
- endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
- // Enqueue a message that tells the message loops to stop.
- receivers.put(PoisonEndpoint)
- executor.shutdown()
- }
-
- def awaitTermination(): Unit = {
- executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
- }
-
- /**
- * Return if the endpoint exists
- */
- def verify(name: String): Boolean = {
- endpoints.containsKey(name)
- }
-
- private class MessageLoop extends Runnable {
- override def run(): Unit = {
- try {
- while (true) {
- try {
- val data = receivers.take()
- if (data == PoisonEndpoint) {
- // Put PoisonEndpoint back so that other MessageLoops can see it.
- receivers.put(PoisonEndpoint)
- return
- }
- data.inbox.process(Dispatcher.this)
- } catch {
- case NonFatal(e) => logError(e.getMessage, e)
- }
- }
- } catch {
- case ie: InterruptedException => // exit
- }
- }
- }
-
- /**
- * A poison endpoint that indicates MessageLoop should exit its loop.
- */
- private val PoisonEndpoint = new EndpointData(null, null, null)
-}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
deleted file mode 100644
index 6061c9b8de..0000000000
--- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.rpc.netty
-
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
-
-/**
- * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists
- */
-private[netty] case class ID(name: String)
-
-/**
- * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
- */
-private[netty] class IDVerifier(
- override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint {
-
- override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case ID(name) => context.reply(dispatcher.verify(name))
- }
-}
-
-private[netty] object IDVerifier {
- val NAME = "id-verifier"
-}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
deleted file mode 100644
index 4803548365..0000000000
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ /dev/null
@@ -1,220 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import java.util.LinkedList
-import javax.annotation.concurrent.GuardedBy
-
-import scala.util.control.NonFatal
-
-import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
-
-private[netty] sealed trait InboxMessage
-
-private[netty] case class ContentMessage(
- senderAddress: RpcAddress,
- content: Any,
- needReply: Boolean,
- context: NettyRpcCallContext) extends InboxMessage
-
-private[netty] case object OnStart extends InboxMessage
-
-private[netty] case object OnStop extends InboxMessage
-
-/**
- * A broadcast message that indicates connecting to a remote node.
- */
-private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage
-
-/**
- * A broadcast message that indicates a remote connection is lost.
- */
-private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage
-
-/**
- * A broadcast message that indicates a network error
- */
-private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress)
- extends InboxMessage
-
-/**
- * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
- * @param endpointRef
- * @param endpoint
- */
-private[netty] class Inbox(
- val endpointRef: NettyRpcEndpointRef,
- val endpoint: RpcEndpoint) extends Logging {
-
- @GuardedBy("this")
- protected val messages = new LinkedList[InboxMessage]()
-
- @GuardedBy("this")
- private var stopped = false
-
- @GuardedBy("this")
- private var enableConcurrent = false
-
- @GuardedBy("this")
- private var workerCount = 0
-
- // OnStart should be the first message to process
- synchronized {
- messages.add(OnStart)
- }
-
- /**
- * Process stored messages.
- */
- def process(dispatcher: Dispatcher): Unit = {
- var message: InboxMessage = null
- synchronized {
- if (!enableConcurrent && workerCount != 0) {
- return
- }
- message = messages.poll()
- if (message != null) {
- workerCount += 1
- } else {
- return
- }
- }
- while (true) {
- safelyCall(endpoint) {
- message match {
- case ContentMessage(_sender, content, needReply, context) =>
- val pf: PartialFunction[Any, Unit] =
- if (needReply) {
- endpoint.receiveAndReply(context)
- } else {
- endpoint.receive
- }
- try {
- pf.applyOrElse[Any, Unit](content, { msg =>
- throw new SparkException(s"Unmatched message $message from ${_sender}")
- })
- if (!needReply) {
- context.finish()
- }
- } catch {
- case NonFatal(e) =>
- if (needReply) {
- // If the sender asks a reply, we should send the error back to the sender
- context.sendFailure(e)
- } else {
- context.finish()
- throw e
- }
- }
-
- case OnStart => {
- endpoint.onStart()
- if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
- synchronized {
- enableConcurrent = true
- }
- }
- }
-
- case OnStop =>
- assert(synchronized { workerCount } == 1)
- dispatcher.removeRpcEndpointRef(endpoint)
- endpoint.onStop()
- assert(isEmpty, "OnStop should be the last message")
-
- case Associated(remoteAddress) =>
- endpoint.onConnected(remoteAddress)
-
- case Disassociated(remoteAddress) =>
- endpoint.onDisconnected(remoteAddress)
-
- case AssociationError(cause, remoteAddress) =>
- endpoint.onNetworkError(cause, remoteAddress)
- }
- }
-
- synchronized {
- // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
- // every time.
- if (!enableConcurrent && workerCount != 1) {
- // If we are not the only one worker, exit
- workerCount -= 1
- return
- }
- message = messages.poll()
- if (message == null) {
- workerCount -= 1
- return
- }
- }
- }
- }
-
- def post(message: InboxMessage): Unit = {
- val dropped =
- synchronized {
- if (stopped) {
- // We already put "OnStop" into "messages", so we should drop further messages
- true
- } else {
- messages.add(message)
- false
- }
- }
- if (dropped) {
- onDrop(message)
- }
- }
-
- def stop(): Unit = synchronized {
- // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last
- // message
- if (!stopped) {
- // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
- // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
- // safely.
- enableConcurrent = false
- stopped = true
- messages.add(OnStop)
- // Note: The concurrent events in messages will be processed one by one.
- }
- }
-
- // Visible for testing.
- protected def onDrop(message: InboxMessage): Unit = {
- logWarning(s"Drop ${message} because $endpointRef is stopped")
- }
-
- def isEmpty: Boolean = synchronized { messages.isEmpty }
-
- private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
- try {
- action
- } catch {
- case NonFatal(e) => {
- try {
- endpoint.onError(e)
- } catch {
- case NonFatal(e) => logWarning(s"Ignore error", e)
- }
- }
- }
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
deleted file mode 100644
index 1876b25592..0000000000
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import java.net.URI
-
-import org.apache.spark.SparkException
-import org.apache.spark.rpc.RpcAddress
-
-private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) {
-
- def toRpcAddress: RpcAddress = RpcAddress(host, port)
-
- override val toString = s"spark://$name@$host:$port"
-}
-
-private[netty] object NettyRpcAddress {
-
- def apply(sparkUrl: String): NettyRpcAddress = {
- try {
- val uri = new URI(sparkUrl)
- val host = uri.getHost
- val port = uri.getPort
- val name = uri.getUserInfo
- if (uri.getScheme != "spark" ||
- host == null ||
- port < 0 ||
- name == null ||
- (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
- uri.getFragment != null ||
- uri.getQuery != null) {
- throw new SparkException("Invalid Spark URL: " + sparkUrl)
- }
- NettyRpcAddress(host, port, name)
- } catch {
- case e: java.net.URISyntaxException =>
- throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
- }
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
deleted file mode 100644
index 75dcc02a0c..0000000000
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import scala.concurrent.Promise
-
-import org.apache.spark.Logging
-import org.apache.spark.network.client.RpcResponseCallback
-import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
-
-private[netty] abstract class NettyRpcCallContext(
- endpointRef: NettyRpcEndpointRef,
- override val senderAddress: RpcAddress,
- needReply: Boolean) extends RpcCallContext with Logging {
-
- protected def send(message: Any): Unit
-
- override def reply(response: Any): Unit = {
- if (needReply) {
- send(AskResponse(endpointRef, response))
- } else {
- throw new IllegalStateException(
- s"Cannot send $response to the sender because the sender won't handle it")
- }
- }
-
- override def sendFailure(e: Throwable): Unit = {
- if (needReply) {
- send(AskResponse(endpointRef, RpcFailure(e)))
- } else {
- logError(e.getMessage, e)
- throw new IllegalStateException(
- "Cannot send reply to the sender because the sender won't handle it")
- }
- }
-
- def finish(): Unit = {
- if (!needReply) {
- send(Ack(endpointRef))
- }
- }
-}
-
-/**
- * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`.
- */
-private[netty] class LocalNettyRpcCallContext(
- endpointRef: NettyRpcEndpointRef,
- senderAddress: RpcAddress,
- needReply: Boolean,
- p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
-
- override protected def send(message: Any): Unit = {
- p.success(message)
- }
-}
-
-/**
- * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back.
- */
-private[netty] class RemoteNettyRpcCallContext(
- nettyEnv: NettyRpcEnv,
- endpointRef: NettyRpcEndpointRef,
- callback: RpcResponseCallback,
- senderAddress: RpcAddress,
- needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
-
- override protected def send(message: Any): Unit = {
- val reply = nettyEnv.serialize(message)
- callback.onSuccess(reply)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
deleted file mode 100644
index 5522b40782..0000000000
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ /dev/null
@@ -1,504 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.rpc.netty
-
-import java.io._
-import java.net.{InetSocketAddress, URI}
-import java.nio.ByteBuffer
-import java.util.Arrays
-import java.util.concurrent._
-import javax.annotation.concurrent.GuardedBy
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-import scala.concurrent.{Future, Promise}
-import scala.reflect.ClassTag
-import scala.util.{DynamicVariable, Failure, Success}
-import scala.util.control.NonFatal
-
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
-import org.apache.spark.network.TransportContext
-import org.apache.spark.network.client._
-import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
-import org.apache.spark.network.server._
-import org.apache.spark.rpc._
-import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
-import org.apache.spark.util.{ThreadUtils, Utils}
-
-private[netty] class NettyRpcEnv(
- val conf: SparkConf,
- javaSerializerInstance: JavaSerializerInstance,
- host: String,
- securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
-
- private val transportConf =
- SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0))
-
- private val dispatcher: Dispatcher = new Dispatcher(this)
-
- private val transportContext =
- new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
-
- private val clientFactory = {
- val bootstraps: Seq[TransportClientBootstrap] =
- if (securityManager.isAuthenticationEnabled()) {
- Seq(new SaslClientBootstrap(transportConf, "", securityManager,
- securityManager.isSaslEncryptionEnabled()))
- } else {
- Nil
- }
- transportContext.createClientFactory(bootstraps.asJava)
- }
-
- val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
-
- // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
- // to implement non-blocking send/ask.
- // TODO: a non-blocking TransportClientFactory.createClient in future
- private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
- "netty-rpc-connection",
- conf.getInt("spark.rpc.connect.threads", 256))
-
- @volatile private var server: TransportServer = _
-
- def start(port: Int): Unit = {
- val bootstraps: Seq[TransportServerBootstrap] =
- if (securityManager.isAuthenticationEnabled()) {
- Seq(new SaslServerBootstrap(transportConf, securityManager))
- } else {
- Nil
- }
- server = transportContext.createServer(port, bootstraps.asJava)
- dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
- }
-
- override lazy val address: RpcAddress = {
- require(server != null, "NettyRpcEnv has not yet started")
- RpcAddress(host, server.getPort())
- }
-
- override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
- dispatcher.registerRpcEndpoint(name, endpoint)
- }
-
- def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
- val addr = NettyRpcAddress(uri)
- val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
- val idVerifierRef =
- new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this)
- idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find =>
- if (find) {
- Future.successful(endpointRef)
- } else {
- Future.failed(new RpcEndpointNotFoundException(uri))
- }
- }(ThreadUtils.sameThread)
- }
-
- override def stop(endpointRef: RpcEndpointRef): Unit = {
- require(endpointRef.isInstanceOf[NettyRpcEndpointRef])
- dispatcher.stop(endpointRef)
- }
-
- private[netty] def send(message: RequestMessage): Unit = {
- val remoteAddr = message.receiver.address
- if (remoteAddr == address) {
- val promise = Promise[Any]()
- dispatcher.postMessage(message, promise)
- promise.future.onComplete {
- case Success(response) =>
- val ack = response.asInstanceOf[Ack]
- logDebug(s"Receive ack from ${ack.sender}")
- case Failure(e) =>
- logError(s"Exception when sending $message", e)
- }(ThreadUtils.sameThread)
- } else {
- try {
- // `createClient` will block if it cannot find a known connection, so we should run it in
- // clientConnectionExecutor
- clientConnectionExecutor.execute(new Runnable {
- override def run(): Unit = Utils.tryLogNonFatalError {
- val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
- client.sendRpc(serialize(message), new RpcResponseCallback {
-
- override def onFailure(e: Throwable): Unit = {
- logError(s"Exception when sending $message", e)
- }
-
- override def onSuccess(response: Array[Byte]): Unit = {
- val ack = deserialize[Ack](response)
- logDebug(s"Receive ack from ${ack.sender}")
- }
- })
- }
- })
- } catch {
- case e: RejectedExecutionException => {
- // `send` after shutting clientConnectionExecutor down, ignore it
- logWarning(s"Cannot send ${message} because RpcEnv is stopped")
- }
- }
- }
- }
-
- private[netty] def ask(message: RequestMessage): Future[Any] = {
- val promise = Promise[Any]()
- val remoteAddr = message.receiver.address
- if (remoteAddr == address) {
- val p = Promise[Any]()
- dispatcher.postMessage(message, p)
- p.future.onComplete {
- case Success(response) =>
- val reply = response.asInstanceOf[AskResponse]
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- case Failure(e) =>
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
- }(ThreadUtils.sameThread)
- } else {
- try {
- // `createClient` will block if it cannot find a known connection, so we should run it in
- // clientConnectionExecutor
- clientConnectionExecutor.execute(new Runnable {
- override def run(): Unit = {
- val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
- client.sendRpc(serialize(message), new RpcResponseCallback {
-
- override def onFailure(e: Throwable): Unit = {
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
- }
-
- override def onSuccess(response: Array[Byte]): Unit = {
- val reply = deserialize[AskResponse](response)
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- }
- })
- }
- })
- } catch {
- case e: RejectedExecutionException => {
- if (!promise.tryFailure(e)) {
- logWarning(s"Ignore failure", e)
- }
- }
- }
- }
- promise.future
- }
-
- private[netty] def serialize(content: Any): Array[Byte] = {
- val buffer = javaSerializerInstance.serialize(content)
- Arrays.copyOfRange(
- buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
- }
-
- private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
- deserialize { () =>
- javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
- }
- }
-
- override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
- dispatcher.getRpcEndpointRef(endpoint)
- }
-
- override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
- new NettyRpcAddress(address.host, address.port, endpointName).toString
-
- override def shutdown(): Unit = {
- cleanup()
- }
-
- override def awaitTermination(): Unit = {
- dispatcher.awaitTermination()
- }
-
- private def cleanup(): Unit = {
- if (timeoutScheduler != null) {
- timeoutScheduler.shutdownNow()
- }
- if (server != null) {
- server.close()
- }
- if (clientFactory != null) {
- clientFactory.close()
- }
- if (dispatcher != null) {
- dispatcher.stop()
- }
- if (clientConnectionExecutor != null) {
- clientConnectionExecutor.shutdownNow()
- }
- }
-
- override def deserialize[T](deserializationAction: () => T): T = {
- NettyRpcEnv.currentEnv.withValue(this) {
- deserializationAction()
- }
- }
-}
-
-private[netty] object NettyRpcEnv extends Logging {
-
- /**
- * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
- * Use `currentEnv` to wrap the deserialization codes. E.g.,
- *
- * {{{
- * NettyRpcEnv.currentEnv.withValue(this) {
- * your deserialization codes
- * }
- * }}}
- */
- private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)
-}
-
-private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
-
- def create(config: RpcEnvConfig): RpcEnv = {
- val sparkConf = config.conf
- // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
- // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
- val javaSerializerInstance =
- new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
- val nettyEnv =
- new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
- val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
- nettyEnv.start(actualPort)
- (nettyEnv, actualPort)
- }
- try {
- Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
- } catch {
- case NonFatal(e) =>
- nettyEnv.shutdown()
- throw e
- }
- }
-}
-
-private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
- extends RpcEndpointRef(conf) with Serializable with Logging {
-
- @transient @volatile private var nettyEnv: NettyRpcEnv = _
-
- @transient @volatile private var _address: NettyRpcAddress = _
-
- def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
- this(conf)
- this._address = _address
- this.nettyEnv = nettyEnv
- }
-
- override def address: RpcAddress = _address.toRpcAddress
-
- private def readObject(in: ObjectInputStream): Unit = {
- in.defaultReadObject()
- _address = in.readObject().asInstanceOf[NettyRpcAddress]
- nettyEnv = NettyRpcEnv.currentEnv.value
- }
-
- private def writeObject(out: ObjectOutputStream): Unit = {
- out.defaultWriteObject()
- out.writeObject(_address)
- }
-
- override def name: String = _address.name
-
- override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
- val promise = Promise[Any]()
- val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable {
- override def run(): Unit = {
- promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration))
- }
- }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
- val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true))
- f.onComplete { v =>
- timeoutCancelable.cancel(true)
- if (!promise.tryComplete(v)) {
- logWarning(s"Ignore message $v")
- }
- }(ThreadUtils.sameThread)
- promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
- }
-
- override def send(message: Any): Unit = {
- require(message != null, "Message is null")
- nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false))
- }
-
- override def toString: String = s"NettyRpcEndpointRef(${_address})"
-
- def toURI: URI = new URI(s"spark://${_address}")
-
- final override def equals(that: Any): Boolean = that match {
- case other: NettyRpcEndpointRef => _address == other._address
- case _ => false
- }
-
- final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
-}
-
-/**
- * The message that is sent from the sender to the receiver.
- */
-private[netty] case class RequestMessage(
- senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean)
-
-/**
- * The base trait for all messages that are sent back from the receiver to the sender.
- */
-private[netty] trait ResponseMessage
-
-/**
- * The reply for `ask` from the receiver side.
- */
-private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any)
- extends ResponseMessage
-
-/**
- * A message to send back to the receiver side. It's necessary because [[TransportClient]] only
- * clean the resources when it receives a reply.
- */
-private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage
-
-/**
- * A response that indicates some failure happens in the receiver side.
- */
-private[netty] case class RpcFailure(e: Throwable)
-
-/**
- * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast
- * network events and forward messages to [[Dispatcher]].
- */
-private[netty] class NettyRpcHandler(
- dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
-
- private type ClientAddress = RpcAddress
- private type RemoteEnvAddress = RpcAddress
-
- // Store all client addresses and their NettyRpcEnv addresses.
- @GuardedBy("this")
- private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
-
- // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection
- // count because `TransportClientFactory.createClient` will create multiple connections
- // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection
- // to send the message. See `TransportClientFactory.createClient` for more details.
- @GuardedBy("this")
- private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]()
-
- override def receive(
- client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
- val requestMessage = nettyEnv.deserialize[RequestMessage](message)
- val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
- assert(addr != null)
- val remoteEnvAddress = requestMessage.senderAddress
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- val broadcastMessage =
- synchronized {
- // If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
- if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
- // clientAddr connects at the first time
- val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
- // Increase the connection number of remoteEnvAddress
- remoteConnectionCount.put(remoteEnvAddress, count + 1)
- if (count == 0) {
- // This is the first connection, so fire "Associated"
- Some(Associated(remoteEnvAddress))
- } else {
- None
- }
- } else {
- None
- }
- }
- broadcastMessage.foreach(dispatcher.broadcastMessage)
- dispatcher.postMessage(requestMessage, callback)
- }
-
- override def getStreamManager: StreamManager = new OneForOneStreamManager
-
- override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
- val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
- if (addr != null) {
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- val broadcastMessage =
- synchronized {
- remoteAddresses.get(clientAddr).map(AssociationError(cause, _))
- }
- if (broadcastMessage.isEmpty) {
- logError(cause.getMessage, cause)
- } else {
- dispatcher.broadcastMessage(broadcastMessage.get)
- }
- } else {
- // If the channel is closed before connecting, its remoteAddress will be null.
- // See java.net.Socket.getRemoteSocketAddress
- // Because we cannot get a RpcAddress, just log it
- logError("Exception before connecting to the client", cause)
- }
- }
-
- override def connectionTerminated(client: TransportClient): Unit = {
- val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
- if (addr != null) {
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- val broadcastMessage =
- synchronized {
- // If the last connection to a remote RpcEnv is terminated, we should broadcast
- // "Disassociated"
- remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
- remoteAddresses -= clientAddr
- val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
- assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent")
- if (count - 1 == 0) {
- // We lost all clients, so clean up and fire "Disassociated"
- remoteConnectionCount.remove(remoteEnvAddress)
- Some(Disassociated(remoteEnvAddress))
- } else {
- // Decrease the connection number of remoteEnvAddress
- remoteConnectionCount.put(remoteEnvAddress, count - 1)
- None
- }
- }
- }
- broadcastMessage.foreach(dispatcher.broadcastMessage)
- } else {
- // If the channel is closed before connecting, its remoteAddress will be null. In this case,
- // we can ignore it since we don't fire "Associated".
- // See java.net.Socket.getRemoteSocketAddress
- }
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
index e749631bf6..7478ab0fc2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
import scala.concurrent.{ExecutionContext, Future}
-import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint}
+import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
import org.apache.spark.storage.BlockManagerMessages._
@@ -33,7 +33,7 @@ class BlockManagerSlaveEndpoint(
override val rpcEnv: RpcEnv,
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
- extends ThreadSafeRpcEndpoint with Logging {
+ extends RpcEndpoint with Logging {
private val asyncThreadPool =
ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
@@ -80,7 +80,7 @@ class BlockManagerSlaveEndpoint(
future.onSuccess { case response =>
logDebug("Done " + actionMessage + ", response is " + response)
context.reply(response)
- logDebug("Sent response: " + response + " to " + context.senderAddress)
+ logDebug("Sent response: " + response + " to " + context.sender)
}
future.onFailure { case t: Throwable =>
logError("Error in " + actionMessage, t)
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 1ed098379e..22e291a2b4 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -85,11 +85,7 @@ private[spark] object ThreadUtils {
*/
def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = {
val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
- val executor = new ScheduledThreadPoolExecutor(1, threadFactory)
- // By default, a cancelled task is not automatically removed from the work queue until its delay
- // elapses. We have to enable it manually.
- executor.setRemoveOnCancelPolicy(true)
- executor
+ Executors.newSingleThreadScheduledExecutor(threadFactory)
}
/**
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 7e70308bb3..af4e68950f 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -168,9 +168,10 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerShuffle(10, 1)
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0)))
- val senderAddress = RpcAddress("localhost", 12345)
+ val sender = mock(classOf[RpcEndpointRef])
+ when(sender.address).thenReturn(RpcAddress("localhost", 12345))
val rpcCallContext = mock(classOf[RpcCallContext])
- when(rpcCallContext.senderAddress).thenReturn(senderAddress)
+ when(rpcCallContext.sender).thenReturn(sender)
masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10))
verify(rpcCallContext).reply(any())
verify(rpcCallContext, never()).sendFailure(any())
@@ -197,9 +198,10 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
}
- val senderAddress = RpcAddress("localhost", 12345)
+ val sender = mock(classOf[RpcEndpointRef])
+ when(sender.address).thenReturn(RpcAddress("localhost", 12345))
val rpcCallContext = mock(classOf[RpcCallContext])
- when(rpcCallContext.senderAddress).thenReturn(senderAddress)
+ when(rpcCallContext.sender).thenReturn(sender)
masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
verify(rpcCallContext, never()).reply(any())
verify(rpcCallContext).sendFailure(isA(classOf[SparkException]))
diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
index 2d14249855..33270bec62 100644
--- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
+++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
@@ -41,7 +41,6 @@ object SSLSampleConfigs {
def sparkSSLConfig(): SparkConf = {
val conf = new SparkConf(loadDefaults = false)
- conf.set("spark.rpc", "akka")
conf.set("spark.ssl.enabled", "true")
conf.set("spark.ssl.keyStore", keyStorePath)
conf.set("spark.ssl.keyStorePassword", "password")
@@ -55,7 +54,6 @@ object SSLSampleConfigs {
def sparkSSLConfigUntrusted(): SparkConf = {
val conf = new SparkConf(loadDefaults = false)
- conf.set("spark.rpc", "akka")
conf.set("spark.ssl.enabled", "true")
conf.set("spark.ssl.keyStore", untrustedKeyStorePath)
conf.set("spark.ssl.keyStorePassword", "password")
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
index 40c24bdecc..e9034e39a7 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
@@ -26,7 +26,8 @@ class WorkerWatcherSuite extends SparkFunSuite {
val conf = new SparkConf()
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
- val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true)
+ val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
+ workerWatcher.setTesting(testing = true)
rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234))
assert(workerWatcher.isShutDown)
@@ -38,7 +39,8 @@ class WorkerWatcherSuite extends SparkFunSuite {
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
val otherRpcAddress = RpcAddress("4.3.2.1", 1234)
- val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true)
+ val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
+ workerWatcher.setTesting(testing = true)
rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
workerWatcher.onDisconnected(otherRpcAddress)
assert(!workerWatcher.isShutDown)
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index e836946a59..6ceafe4337 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.rpc
-import java.io.NotSerializableException
import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException}
import scala.collection.mutable
@@ -100,6 +99,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint)
+
val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello")
val reply = newRpcEndpointRef.askWithRetry[String]("Echo")
assert("Echo" === reply)
@@ -516,9 +516,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(events === List(
("onConnected", remoteAddress),
("onNetworkError", remoteAddress),
- ("onDisconnected", remoteAddress)) ||
- events === List(
- ("onConnected", remoteAddress),
("onDisconnected", remoteAddress)))
}
}
@@ -538,84 +535,15 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
"local", env.address, "sendWithReply-unserializable-error")
try {
val f = rpcEndpointRef.ask[String]("hello")
- val e = intercept[Exception] {
+ intercept[TimeoutException] {
Await.result(f, 1 seconds)
}
- assert(e.isInstanceOf[TimeoutException] || // For Akka
- e.isInstanceOf[NotSerializableException] // For Netty
- )
} finally {
anotherEnv.shutdown()
anotherEnv.awaitTermination()
}
}
- test("port conflict") {
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port)
- assert(anotherEnv.address.port != env.address.port)
- }
-
- test("send with authentication") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
-
- val localEnv = createRpcEnv(conf, "authentication-local", 13345)
- val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
-
- try {
- @volatile var message: String = null
- localEnv.setupEndpoint("send-authentication", new RpcEndpoint {
- override val rpcEnv = localEnv
-
- override def receive: PartialFunction[Any, Unit] = {
- case msg: String => message = msg
- }
- })
- val rpcEndpointRef =
- remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication")
- rpcEndpointRef.send("hello")
- eventually(timeout(5 seconds), interval(10 millis)) {
- assert("hello" === message)
- }
- } finally {
- localEnv.shutdown()
- localEnv.awaitTermination()
- remoteEnv.shutdown()
- remoteEnv.awaitTermination()
- }
- }
-
- test("ask with authentication") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
-
- val localEnv = createRpcEnv(conf, "authentication-local", 13345)
- val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
-
- try {
- localEnv.setupEndpoint("ask-authentication", new RpcEndpoint {
- override val rpcEnv = localEnv
-
- override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case msg: String => {
- context.reply(msg)
- }
- }
- })
- val rpcEndpointRef =
- remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication")
- val reply = rpcEndpointRef.askWithRetry[String]("hello")
- assert("hello" === reply)
- } finally {
- localEnv.shutdown()
- localEnv.awaitTermination()
- remoteEnv.shutdown()
- remoteEnv.awaitTermination()
- }
- }
-
test("construct RpcTimeout with conf property") {
val conf = new SparkConf
@@ -684,7 +612,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
// once the future is complete to verify addMessageIfTimeout was invoked
val reply3 =
intercept[RpcTimeoutException] {
- Await.result(fut3, 2000 millis)
+ Await.result(fut3, 200 millis)
}.getMessage
// When the future timed out, the recover callback should have used
diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala
deleted file mode 100644
index 5e8da3e205..0000000000
--- a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala
+++ /dev/null
@@ -1,123 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.scalactic.TripleEquals
-
-class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals {
-
- override val rpcEnv: RpcEnv = null
-
- @volatile private var receiveMessages = ArrayBuffer[Any]()
-
- @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]()
-
- @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]()
-
- @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]()
-
- @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable, RpcAddress)]()
-
- @volatile private var started = false
-
- @volatile private var stopped = false
-
- override def receive: PartialFunction[Any, Unit] = {
- case message: Any => receiveMessages += message
- }
-
- override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case message: Any => receiveAndReplyMessages += message
- }
-
- override def onConnected(remoteAddress: RpcAddress): Unit = {
- onConnectedMessages += remoteAddress
- }
-
- /**
- * Invoked when some network error happens in the connection between the current node and
- * `remoteAddress`.
- */
- override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
- onNetworkErrorMessages += cause -> remoteAddress
- }
-
- override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- onDisconnectedMessages += remoteAddress
- }
-
- def numReceiveMessages: Int = receiveMessages.size
-
- override def onStart(): Unit = {
- started = true
- }
-
- override def onStop(): Unit = {
- stopped = true
- }
-
- def verifyStarted(): Unit = {
- assert(started, "RpcEndpoint is not started")
- }
-
- def verifyStopped(): Unit = {
- assert(stopped, "RpcEndpoint is not stopped")
- }
-
- def verifyReceiveMessages(expected: Seq[Any]): Unit = {
- assert(receiveMessages === expected)
- }
-
- def verifySingleReceiveMessage(message: Any): Unit = {
- verifyReceiveMessages(List(message))
- }
-
- def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = {
- assert(receiveAndReplyMessages === expected)
- }
-
- def verifySingleReceiveAndReplyMessage(message: Any): Unit = {
- verifyReceiveAndReplyMessages(List(message))
- }
-
- def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = {
- verifyOnConnectedMessages(List(remoteAddress))
- }
-
- def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = {
- assert(onConnectedMessages === expected)
- }
-
- def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = {
- verifyOnDisconnectedMessages(List(remoteAddress))
- }
-
- def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = {
- assert(onDisconnectedMessages === expected)
- }
-
- def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress: RpcAddress): Unit = {
- verifyOnNetworkErrorMessages(List(cause -> remoteAddress))
- }
-
- def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]): Unit = {
- assert(onNetworkErrorMessages === expected)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
deleted file mode 100644
index ff83ab9b32..0000000000
--- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import java.util.concurrent.{CountDownLatch, TimeUnit}
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.mockito.Mockito._
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint}
-
-class InboxSuite extends SparkFunSuite {
-
- test("post") {
- val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
- when(endpointRef.name).thenReturn("hello")
-
- val dispatcher = mock(classOf[Dispatcher])
-
- val inbox = new Inbox(endpointRef, endpoint)
- val message = ContentMessage(null, "hi", false, null)
- inbox.post(message)
- inbox.process(dispatcher)
- assert(inbox.isEmpty)
-
- endpoint.verifySingleReceiveMessage("hi")
-
- inbox.stop()
- inbox.process(dispatcher)
- assert(inbox.isEmpty)
- endpoint.verifyStarted()
- endpoint.verifyStopped()
- }
-
- test("post: with reply") {
- val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
- val dispatcher = mock(classOf[Dispatcher])
-
- val inbox = new Inbox(endpointRef, endpoint)
- val message = ContentMessage(null, "hi", true, null)
- inbox.post(message)
- inbox.process(dispatcher)
- assert(inbox.isEmpty)
-
- endpoint.verifySingleReceiveAndReplyMessage("hi")
- }
-
- test("post: multiple threads") {
- val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
- when(endpointRef.name).thenReturn("hello")
-
- val dispatcher = mock(classOf[Dispatcher])
-
- val numDroppedMessages = new AtomicInteger(0)
- val inbox = new Inbox(endpointRef, endpoint) {
- override def onDrop(message: InboxMessage): Unit = {
- numDroppedMessages.incrementAndGet()
- }
- }
-
- val exitLatch = new CountDownLatch(10)
-
- for (_ <- 0 until 10) {
- new Thread {
- override def run(): Unit = {
- for (_ <- 0 until 100) {
- val message = ContentMessage(null, "hi", false, null)
- inbox.post(message)
- }
- exitLatch.countDown()
- }
- }.start()
- }
- inbox.process(dispatcher)
- assert(inbox.isEmpty)
- inbox.stop()
- inbox.process(dispatcher)
- assert(inbox.isEmpty)
-
- exitLatch.await(30, TimeUnit.SECONDS)
-
- assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get)
- endpoint.verifyStarted()
- endpoint.verifyStopped()
- }
-
- test("post: Associated") {
- val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
- val dispatcher = mock(classOf[Dispatcher])
-
- val remoteAddress = RpcAddress("localhost", 11111)
-
- val inbox = new Inbox(endpointRef, endpoint)
- inbox.post(Associated(remoteAddress))
- inbox.process(dispatcher)
-
- endpoint.verifySingleOnConnectedMessage(remoteAddress)
- }
-
- test("post: Disassociated") {
- val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
- val dispatcher = mock(classOf[Dispatcher])
-
- val remoteAddress = RpcAddress("localhost", 11111)
-
- val inbox = new Inbox(endpointRef, endpoint)
- inbox.post(Disassociated(remoteAddress))
- inbox.process(dispatcher)
-
- endpoint.verifySingleOnDisconnectedMessage(remoteAddress)
- }
-
- test("post: AssociationError") {
- val endpoint = new TestRpcEndpoint
- val endpointRef = mock(classOf[NettyRpcEndpointRef])
- val dispatcher = mock(classOf[Dispatcher])
-
- val remoteAddress = RpcAddress("localhost", 11111)
- val cause = new RuntimeException("Oops")
-
- val inbox = new Inbox(endpointRef, endpoint)
- inbox.post(AssociationError(cause, remoteAddress))
- inbox.process(dispatcher)
-
- endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
deleted file mode 100644
index a5d43d3704..0000000000
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import org.apache.spark.SparkFunSuite
-
-class NettyRpcAddressSuite extends SparkFunSuite {
-
- test("toString") {
- val addr = NettyRpcAddress("localhost", 12345, "test")
- assert(addr.toString === "spark://test@localhost:12345")
- }
-
-}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
deleted file mode 100644
index be19668e17..0000000000
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import org.apache.spark.{SecurityManager, SparkConf}
-import org.apache.spark.rpc._
-
-class NettyRpcEnvSuite extends RpcEnvSuite {
-
- override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
- val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf))
- new NettyRpcEnvFactory().create(config)
- }
-
- test("non-existent endpoint") {
- val uri = env.uriOf("test", env.address, "nonexist-endpoint")
- val e = intercept[RpcEndpointNotFoundException] {
- env.setupEndpointRef("test", env.address, "nonexist-endpoint")
- }
- assert(e.getMessage.contains(uri))
- }
-
-}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
deleted file mode 100644
index 06ca035d19..0000000000
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.netty
-
-import java.net.InetSocketAddress
-
-import io.netty.channel.Channel
-import org.mockito.Mockito._
-import org.mockito.Matchers._
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
-import org.apache.spark.rpc._
-
-class NettyRpcHandlerSuite extends SparkFunSuite {
-
- val env = mock(classOf[NettyRpcEnv])
- when(env.deserialize(any(classOf[Array[Byte]]))(any())).
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
-
- test("receive") {
- val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
-
- val channel = mock(classOf[Channel])
- val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
- when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.receive(client, null, null)
-
- when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001))
- nettyRpcHandler.receive(client, null, null)
-
- verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
- }
-
- test("connectionTerminated") {
- val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
-
- val channel = mock(classOf[Channel])
- val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
- when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.receive(client, null, null)
-
- when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
- nettyRpcHandler.connectionTerminated(client)
-
- verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
- verify(dispatcher, times(1)).broadcastMessage(Disassociated(RpcAddress("localhost", 12345)))
- }
-
-}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index fbb8bb6b2f..df841288a0 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -78,10 +78,6 @@ public class TransportClient implements Closeable {
this.handler = Preconditions.checkNotNull(handler);
}
- public Channel getChannel() {
- return channel;
- }
-
public boolean isActive() {
return channel.isOpen() || channel.isActive();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index dbb7f95f55..2ba92a40f8 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -52,6 +52,4 @@ public abstract class RpcHandler {
* No further requests will come from this client.
*/
public void connectionTerminated(TransportClient client) { }
-
- public void exceptionCaught(Throwable cause, TransportClient client) { }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 96941d26be..df60278058 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -71,7 +71,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
@Override
public void exceptionCaught(Throwable cause) {
- rpcHandler.exceptionCaught(cause, reverseClient);
}
@Override
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index d053e9e849..204e6142fd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -474,7 +474,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// Remote messages
case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) =>
val successful =
- registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.senderAddress)
+ registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address)
context.reply(successful)
case AddBlock(receivedBlockInfo) =>
context.reply(addBlock(receivedBlockInfo))
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 32d218169a..93621b44c9 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -556,7 +556,10 @@ private[spark] class ApplicationMaster(
override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean)
extends RpcEndpoint with Logging {
- driver.send(RegisterClusterManager(self))
+ override def onStart(): Unit = {
+ driver.send(RegisterClusterManager(self))
+
+ }
override def receive: PartialFunction[Any, Unit] = {
case x: AddWebUIFilter =>