author    zsxwing <zsxwing@gmail.com>  2015-03-29 21:25:09 -0700
committer Reynold Xin <rxin@databricks.com>  2015-03-29 21:25:09 -0700
commit    a8d53afb4e119788fa0d9dd6b3e3ca94cea98581 (patch)
tree      10d1062830ccc923819c2a7f31b4d89af6ba33ed /core
parent    0e2753ff14e0d3f2433272c13ce26f67dc89767f (diff)
[SPARK-5124][Core] A standard RPC interface and an Akka implementation
This PR added a standard internal RPC interface for Spark and an Akka implementation. See [the design document](https://issues.apache.org/jira/secure/attachment/12698710/Pluggable%20RPC%20-%20draft%202.pdf) for more details. I will split the whole work into multiple PRs to make code review easier. This is the first PR and it avoids touching too many files.

Author: zsxwing <zsxwing@gmail.com>

Closes #4588 from zsxwing/rpc-part1 and squashes the following commits:

fe3df4c [zsxwing] Move registerEndpoint and use actorSystem.dispatcher in asyncSetupEndpointRefByURI
f6f3287 [zsxwing] Remove RpcEndpointRef.toURI
8bd1097 [zsxwing] Fix docs and the code style
f459380 [zsxwing] Add RpcAddress.fromURI and rename urls to uris
b221398 [zsxwing] Move send methods above ask methods
15cfd7b [zsxwing] Merge branch 'master' into rpc-part1
9ffa997 [zsxwing] Fix MiMa tests
78a1733 [zsxwing] Merge remote-tracking branch 'origin/master' into rpc-part1
385b9c3 [zsxwing] Fix the code style and add docs
2cc3f78 [zsxwing] Add an asynchronous version of setupEndpointRefByUrl
e8dfec3 [zsxwing] Remove 'sendWithReply(message: Any, sender: RpcEndpointRef): Unit'
08564ae [zsxwing] Add RpcEnvFactory to create RpcEnv
e5df4ca [zsxwing] Handle AkkaFailure(e) in Actor
ec7c5b0 [zsxwing] Fix docs
7fc95e1 [zsxwing] Implement askWithReply in RpcEndpointRef
9288406 [zsxwing] Document thread-safety for setupThreadSafeEndpoint
3007c09 [zsxwing] Move setupDriverEndpointRef to RpcUtils and rename to makeDriverRef
c425022 [zsxwing] Fix the code style
5f87700 [zsxwing] Move the logical of processing message to a private function
3e56123 [zsxwing] Use lazy to eliminate CountDownLatch
07f128f [zsxwing] Remove ActionScheduler.scala
4d34191 [zsxwing] Remove scheduler from RpcEnv
7cdd95e [zsxwing] Add docs for RpcEnv
51e6667 [zsxwing] Add 'sender' to RpcCallContext and rename the parameter of receiveAndReply to 'context'
ffc1280 [zsxwing] Rename 'fail' to 'sendFailure' and other minor code style changes
28e6d0f [zsxwing] Add onXXX for network events and remove the companion objects of network events
3751c97 [zsxwing] Rename RpcResponse to RpcCallContext
fe7d1ff [zsxwing] Add explicit reply in rpc
7b9e0c9 [zsxwing] Fix the indentation
04a106e [zsxwing] Remove NopCancellable and add a const NOP in object SettableCancellable
2a579f4 [zsxwing] Remove RpcEnv.systemName
155b987 [zsxwing] Change newURI to uriOf and add some comments
45b2317 [zsxwing] A standard RPC interface and An Akka implementation
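For orientation, here is a minimal usage sketch of the interface this patch introduces (not part of the patch itself). The `RpcExample` object, the `echo` endpoint name, and the messages are made up for illustration; `RpcEnv.create`, `setupEndpoint`, `send` and `askWithReply` are the API added in `core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala` below:

```scala
// Hypothetical example package; the RPC API is private[spark],
// so real callers must live under org.apache.spark.
package org.apache.spark.rpc.example

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc._

object RpcExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // RpcEnv.create selects the backend via the "spark.rpc" setting (the Akka implementation by default).
    val env: RpcEnv = RpcEnv.create("example", "localhost", 0, conf, new SecurityManager(conf))

    // Register an endpoint. `receive` handles one-way messages from `send`;
    // `receiveAndReply` handles messages sent with `sendWithReply`/`askWithReply`.
    val ref: RpcEndpointRef = env.setupEndpoint("echo", new RpcEndpoint {
      override val rpcEnv: RpcEnv = env

      override def receive: PartialFunction[Any, Unit] = {
        case msg: String => println(s"one-way message: $msg")
      }

      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply(msg)
      }
    })

    ref.send("fire-and-forget")                    // no reply expected
    val echoed = ref.askWithReply[String]("hello") // blocking ask, retried on failure
    assert(echoed == "hello")

    env.shutdown()
    env.awaitTermination()
  }
}
```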
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                              |  42
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala           |  11
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala           |  59
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala                            | 429
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala                   | 318
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala     |  37
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AkkaUtils.scala                        |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/RpcUtils.scala                         |  35
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala      |  38
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala                       | 525
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala              |  50
12 files changed, 1463 insertions(+), 85 deletions(-)
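As a rough sketch of the pluggability described in the design document: `RpcEnv.create` resolves the `spark.rpc` setting through an `RpcEnvFactory` (see `getRpcEnvFactory` in `RpcEnv.scala` below), mapping the short name `akka` to `AkkaRpcEnvFactory` and treating any other value as a fully qualified factory class name. `MyRpcEnvFactory` and its package here are hypothetical names, not part of this patch:

```scala
package org.apache.spark.rpc.custom  // hypothetical package

import org.apache.spark.SparkConf
import org.apache.spark.rpc.{RpcEnv, RpcEnvConfig, RpcEnvFactory}

// The factory must have an empty constructor because RpcEnv.create
// instantiates it reflectively from the configured class name.
private[spark] class MyRpcEnvFactory extends RpcEnvFactory {
  override def create(config: RpcEnvConfig): RpcEnv = {
    // ... build a concrete RpcEnv bound to config.host and config.port ...
    throw new UnsupportedOperationException("illustration only")
  }
}

object SelectBackend {
  def main(args: Array[String]): Unit = {
    // "akka" (the default) maps to AkkaRpcEnvFactory; any other value is loaded as a class name.
    val conf = new SparkConf().set("spark.rpc", "org.apache.spark.rpc.custom.MyRpcEnvFactory")
    // Subsequent RpcEnv.create(...) calls with this conf would go through MyRpcEnvFactory.
  }
}
```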
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 2a0c7e756d..4a2ed82a40 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,12 +34,14 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv}
+import org.apache.spark.rpc.akka.AkkaRpcEnv
import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
-import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor
+import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
/**
* :: DeveloperApi ::
@@ -54,7 +56,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
@DeveloperApi
class SparkEnv (
val executorId: String,
- val actorSystem: ActorSystem,
+ private[spark] val rpcEnv: RpcEnv,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -71,6 +73,9 @@ class SparkEnv (
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
+ // TODO Remove actorSystem
+ val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+
private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
@@ -91,7 +96,8 @@ class SparkEnv (
blockManager.master.stop()
metricsSystem.stop()
outputCommitCoordinator.stop()
- actorSystem.shutdown()
+ rpcEnv.shutdown()
+
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
// UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
@@ -236,16 +242,15 @@ object SparkEnv extends Logging {
val securityManager = new SecurityManager(conf)
// Create the ActorSystem for Akka and get the port it binds to.
- val (actorSystem, boundPort) = {
- val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
- AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
- }
+ val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
+ val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
+ val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
if (isDriver) {
- conf.set("spark.driver.port", boundPort.toString)
+ conf.set("spark.driver.port", rpcEnv.address.port.toString)
} else {
- conf.set("spark.executor.port", boundPort.toString)
+ conf.set("spark.executor.port", rpcEnv.address.port.toString)
}
// Create an instance of the class with the given name, possibly initializing it with our conf
@@ -290,6 +295,15 @@ object SparkEnv extends Logging {
}
}
+ def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = {
+ if (isDriver) {
+ logInfo("Registering " + name)
+ rpcEnv.setupEndpoint(name, endpointCreator)
+ } else {
+ RpcUtils.makeDriverRef(name, conf, rpcEnv)
+ }
+ }
+
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf)
} else {
@@ -377,13 +391,13 @@ object SparkEnv extends Logging {
val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
new OutputCommitCoordinator(conf)
}
- val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
- new OutputCommitCoordinatorActor(outputCommitCoordinator))
- outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)
+ val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator",
+ new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
+ outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
val envInstance = new SparkEnv(
executorId,
- actorSystem,
+ rpcEnv,
serializer,
closureSerializer,
cacheManager,
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index deef6ef904..d1a12b01e7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -19,10 +19,9 @@ package org.apache.spark.deploy.worker
import java.io.File
-import akka.actor._
-
import org.apache.spark.{SecurityManager, SparkConf}
-import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
+import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
/**
* Utility object for launching driver programs such that they share fate with the Worker process.
@@ -39,9 +38,9 @@ object DriverWrapper {
*/
case workerUrl :: userJar :: mainClass :: extraArgs =>
val conf = new SparkConf()
- val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
+ val rpcEnv = RpcEnv.create("Driver",
Utils.localHostName(), 0, conf, new SecurityManager(conf))
- actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
+ rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl))
val currentLoader = Thread.currentThread.getContextClassLoader
val userJarUrl = new File(userJar).toURI().toURL()
@@ -58,7 +57,7 @@ object DriverWrapper {
val mainMethod = clazz.getMethod("main", classOf[Array[String]])
mainMethod.invoke(null, extraArgs.toArray[String])
- actorSystem.shutdown()
+ rpcEnv.shutdown()
case _ =>
System.err.println("Usage: DriverWrapper <workerUrl> <userJar> <driverMainClass> [options]")
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index e0790274d7..83fb991891 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -17,58 +17,63 @@
package org.apache.spark.deploy.worker
-import akka.actor.{Actor, Address, AddressFromURIString}
-import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent}
-
import org.apache.spark.Logging
import org.apache.spark.deploy.DeployMessages.SendHeartbeat
-import org.apache.spark.util.ActorLogReceive
+import org.apache.spark.rpc._
/**
* Actor which connects to a worker process and terminates the JVM if the connection is severed.
* Provides fate sharing between a worker and its associated child processes.
*/
-private[spark] class WorkerWatcher(workerUrl: String)
- extends Actor with ActorLogReceive with Logging {
-
- override def preStart() {
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String)
+ extends RpcEndpoint with Logging {
+ override def onStart() {
logInfo(s"Connecting to worker $workerUrl")
- val worker = context.actorSelection(workerUrl)
- worker ! SendHeartbeat // need to send a message here to initiate connection
+ if (!isTesting) {
+ rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
+ }
}
// Used to avoid shutting down JVM during tests
+  // In the normal case, exitNonZero will call `System.exit(-1)` to shut down the JVM. In unit
+  // tests, the user should call `setTesting(true)` so that `exitNonZero` sets `isShutDown` to
+  // true instead of calling `System.exit`. The user can then check `isShutDown` to know whether
+  // `exitNonZero` has been called.
private[deploy] var isShutDown = false
private[deploy] def setTesting(testing: Boolean) = isTesting = testing
private var isTesting = false
// Lets us filter events only from the worker's actor system
- private val expectedHostPort = AddressFromURIString(workerUrl).hostPort
- private def isWorker(address: Address) = address.hostPort == expectedHostPort
+ private val expectedAddress = RpcAddress.fromURIString(workerUrl)
+ private def isWorker(address: RpcAddress) = expectedAddress == address
private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1)
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
- logInfo(s"Successfully connected to $workerUrl")
+ override def receive: PartialFunction[Any, Unit] = {
+ case e => logWarning(s"Received unexpected message: $e")
+ }
- case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _)
- if isWorker(remoteAddress) =>
- // These logs may not be seen if the worker (and associated pipe) has died
- logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
- logError(s"Error was: $cause")
- exitNonZero()
+ override def onConnected(remoteAddress: RpcAddress): Unit = {
+ if (isWorker(remoteAddress)) {
+ logInfo(s"Successfully connected to $workerUrl")
+ }
+ }
- case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (isWorker(remoteAddress)) {
// This log message will never be seen
logError(s"Lost connection to worker actor $workerUrl. Exiting.")
exitNonZero()
+ }
+ }
- case e: AssociationEvent =>
- // pass through association events relating to other remote actor systems
-
- case e => logWarning(s"Received unexpected actor system event: $e")
+ override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ if (isWorker(remoteAddress)) {
+ // These logs may not be seen if the worker (and associated pipe) has died
+ logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
+ logError(s"Error was: $cause")
+ exitNonZero()
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index b5205d4e99..900e678ee0 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -169,7 +169,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
driverUrl, executorId, sparkHostPort, cores, userClassPath, env),
name = "Executor")
workerUrl.foreach { url =>
- env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
+ env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
}
env.actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
new file mode 100644
index 0000000000..7985941d94
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import java.net.URI
+
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf}
+import org.apache.spark.util.{AkkaUtils, Utils}
+
+/**
+ * An RPC environment. [[RpcEndpoint]]s need to register themselves with a name to an [[RpcEnv]]
+ * in order to receive messages. [[RpcEnv]] will then process messages sent from an
+ * [[RpcEndpointRef]] or a remote node, and deliver them to the corresponding [[RpcEndpoint]]s.
+ *
+ * [[RpcEnv]] also provides methods to retrieve [[RpcEndpointRef]]s given a name or URI.
+ */
+private[spark] abstract class RpcEnv(conf: SparkConf) {
+
+ private[spark] val defaultLookupTimeout = AkkaUtils.lookupTimeout(conf)
+
+ /**
+ * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement
+ * [[RpcEndpoint.self]].
+ *
+ * Note: This method won't return null. `IllegalArgumentException` will be thrown if calling this
+ * on a non-existent endpoint.
+ */
+ private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef
+
+ /**
+ * Return the address that [[RpcEnv]] is listening to.
+ */
+ def address: RpcAddress
+
+ /**
+ * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] does not
+ * guarantee thread-safety.
+ */
+ def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
+
+ /**
+ * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should
+ * make sure that messages are delivered to the [[RpcEndpoint]] in a thread-safe manner.
+ *
+ * Thread-safety means processing of one message happens before processing of the next message by
+ * the same [[RpcEndpoint]]. In other words, changes to internal fields of a [[RpcEndpoint]]
+ * are visible when processing the next message, and fields in the [[RpcEndpoint]] need not be
+ * volatile or equivalent.
+ *
+ * However, there is no guarantee that the same thread will be executing the same [[RpcEndpoint]]
+ * for different messages.
+ */
+ def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
+
+ /**
+ * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously.
+ */
+ def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef]
+
+ /**
+ * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action.
+ */
+ def setupEndpointRefByURI(uri: String): RpcEndpointRef = {
+ Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout)
+ }
+
+ /**
+ * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`
+ * asynchronously.
+ */
+ def asyncSetupEndpointRef(
+ systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = {
+ asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName))
+ }
+
+ /**
+ * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`.
+ * This is a blocking action.
+ */
+ def setupEndpointRef(
+ systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = {
+ setupEndpointRefByURI(uriOf(systemName, address, endpointName))
+ }
+
+ /**
+ * Stop [[RpcEndpoint]] specified by `endpoint`.
+ */
+ def stop(endpoint: RpcEndpointRef): Unit
+
+ /**
+ * Shut down this [[RpcEnv]] asynchronously. If you need to make sure the [[RpcEnv]] exits successfully,
+ * call [[awaitTermination()]] straight after [[shutdown()]].
+ */
+ def shutdown(): Unit
+
+ /**
+ * Wait until [[RpcEnv]] exits.
+ *
+ * TODO do we need a timeout parameter?
+ */
+ def awaitTermination(): Unit
+
+ /**
+ * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of
+ * creating it manually because different [[RpcEnv]] may have different formats.
+ */
+ def uriOf(systemName: String, address: RpcAddress, endpointName: String): String
+}
+
+private[spark] case class RpcEnvConfig(
+ conf: SparkConf,
+ name: String,
+ host: String,
+ port: Int,
+ securityManager: SecurityManager)
+
+/**
+ * An [[RpcEnv]] implementation must have a [[RpcEnvFactory]] implementation with an empty
+ * constructor so that it can be created via reflection.
+ */
+private[spark] object RpcEnv {
+
+ private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
+ // Add more RpcEnv implementations here
+ val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
+ val rpcEnvName = conf.get("spark.rpc", "akka")
+ val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
+ Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader).
+ newInstance().asInstanceOf[RpcEnvFactory]
+ }
+
+ def create(
+ name: String,
+ host: String,
+ port: Int,
+ conf: SparkConf,
+ securityManager: SecurityManager): RpcEnv = {
+    // Use reflection to create the RpcEnv so that we don't depend on Akka directly
+ val config = RpcEnvConfig(conf, name, host, port, securityManager)
+ getRpcEnvFactory(conf).create(config)
+ }
+
+}
+
+/**
+ * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be
+ * created using Reflection.
+ */
+private[spark] trait RpcEnvFactory {
+
+ def create(config: RpcEnvConfig): RpcEnv
+}
+
+/**
+ * An endpoint for RPC that defines what functions to trigger given a message.
+ *
+ * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
+ *
+ * The life-cycle is:
+ *
+ * constructor onStart receive* onStop
+ *
+ * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use
+ * [[RpcEnv.setupThreadSafeEndpoint]]
+ *
+ * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be
+ * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it.
+ */
+private[spark] trait RpcEndpoint {
+
+ /**
+ * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to.
+ */
+ val rpcEnv: RpcEnv
+
+ /**
+ * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is
+ * called.
+ *
+ * Note: before `onStart` is called, the [[RpcEndpoint]] has not yet been registered and there is
+ * no valid [[RpcEndpointRef]] for it, so don't call `self` before `onStart` is called.
+ */
+ final def self: RpcEndpointRef = {
+ require(rpcEnv != null, "rpcEnv has not been initialized")
+ rpcEnv.endpointRef(this)
+ }
+
+ /**
+ * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply]]. If an unmatched
+ * message is received, a [[SparkException]] will be thrown and sent to `onError`.
+ */
+ def receive: PartialFunction[Any, Unit] = {
+ case _ => throw new SparkException(self + " does not implement 'receive'")
+ }
+
+ /**
+ * Process messages from [[RpcEndpointRef.sendWithReply]]. If an unmatched message is received,
+ * a [[SparkException]] will be thrown and sent to `onError`.
+ */
+ def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
+ }
+
+ /**
+ * Invoked when any exception is thrown while handling messages.
+ *
+ * @param cause the exception that was thrown
+ */
+ def onError(cause: Throwable): Unit = {
+ // By default, throw e and let RpcEnv handle it
+ throw cause
+ }
+
+ /**
+ * Invoked before [[RpcEndpoint]] starts to handle any message.
+ */
+ def onStart(): Unit = {
+ // By default, do nothing.
+ }
+
+ /**
+ * Invoked when [[RpcEndpoint]] is stopping.
+ */
+ def onStop(): Unit = {
+ // By default, do nothing.
+ }
+
+ /**
+ * Invoked when `remoteAddress` is connected to the current node.
+ */
+ def onConnected(remoteAddress: RpcAddress): Unit = {
+ // By default, do nothing.
+ }
+
+ /**
+ * Invoked when `remoteAddress` is lost.
+ */
+ def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ // By default, do nothing.
+ }
+
+ /**
+ * Invoked when some network error happens in the connection between the current node and
+ * `remoteAddress`.
+ */
+ def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ // By default, do nothing.
+ }
+
+ /**
+ * A convenience method to stop this [[RpcEndpoint]].
+ */
+ final def stop(): Unit = {
+ val _self = self
+ if (_self != null) {
+ rpcEnv.stop(self)
+ }
+ }
+}
+
+/**
+ * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
+ */
+private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
+ extends Serializable with Logging {
+
+ private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3)
+ private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000)
+ private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds
+
+ /**
+ * Return the address for this [[RpcEndpointRef]].
+ */
+ def address: RpcAddress
+
+ def name: String
+
+ /**
+ * Sends a one-way asynchronous message. Fire-and-forget semantics.
+ */
+ def send(message: Any): Unit
+
+ /**
+ * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and return a `Future` to
+ * receive the reply within a default timeout.
+ *
+ * This method only sends the message once and never retries.
+ */
+ def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout)
+
+ /**
+ * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and return a `Future` to
+ * receive the reply within the specified timeout.
+ *
+ * This method only sends the message once and never retries.
+ */
+ def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T]
+
+ /**
+ * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default
+ * timeout, or throw a SparkException if this fails even after the default number of retries.
+ * The default `timeout` is used for every `sendWithReply` attempt. Because this method retries,
+ * the message handling on the receiver side should be idempotent.
+ *
+ * Note: this is a blocking action which may take a long time, so don't call it in a message
+ * loop of an [[RpcEndpoint]].
+ *
+ * @param message the message to send
+ * @tparam T type of the reply message
+ * @return the reply message from the corresponding [[RpcEndpoint]]
+ */
+ def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout)
+
+ /**
+ * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
+ * specified timeout, or throw a SparkException if this fails even after the specified number of
+ * retries. `timeout` is used for every `sendWithReply` attempt. Because this method retries,
+ * the message handling on the receiver side should be idempotent.
+ *
+ * Note: this is a blocking action which may take a long time, so don't call it in a message
+ * loop of an [[RpcEndpoint]].
+ *
+ * @param message the message to send
+ * @param timeout the timeout duration
+ * @tparam T type of the reply message
+ * @return the reply message from the corresponding [[RpcEndpoint]]
+ */
+ def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = {
+ // TODO: Consider removing multiple attempts
+ var attempts = 0
+ var lastException: Exception = null
+ while (attempts < maxRetries) {
+ attempts += 1
+ try {
+ val future = sendWithReply[T](message, timeout)
+ val result = Await.result(future, timeout)
+ if (result == null) {
+ throw new SparkException("Actor returned null")
+ }
+ return result
+ } catch {
+ case ie: InterruptedException => throw ie
+ case e: Exception =>
+ lastException = e
+ logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
+ }
+ Thread.sleep(retryWaitMs)
+ }
+
+ throw new SparkException(
+ s"Error sending message [message = $message]", lastException)
+ }
+
+}
+
+/**
+ * Represents a host and a port.
+ */
+private[spark] case class RpcAddress(host: String, port: Int) {
+ // TODO do we need to add the type of RpcEnv in the address?
+
+ val hostPort: String = host + ":" + port
+
+ override val toString: String = hostPort
+}
+
+private[spark] object RpcAddress {
+
+ /**
+ * Return the [[RpcAddress]] represented by `uri`.
+ */
+ def fromURI(uri: URI): RpcAddress = {
+ RpcAddress(uri.getHost, uri.getPort)
+ }
+
+ /**
+ * Return the [[RpcAddress]] represented by `uri`.
+ */
+ def fromURIString(uri: String): RpcAddress = {
+ fromURI(new java.net.URI(uri))
+ }
+
+ def fromSparkURL(sparkUrl: String): RpcAddress = {
+ val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
+ RpcAddress(host, port)
+ }
+}
+
+/**
+ * A callback that [[RpcEndpoint]] can use to send back a message or a failure.
+ */
+private[spark] trait RpcCallContext {
+
+ /**
+ * Reply with a message to the sender. If the sender is an [[RpcEndpoint]], its [[RpcEndpoint.receive]]
+ * will be called.
+ */
+ def reply(response: Any): Unit
+
+ /**
+ * Report a failure to the sender.
+ */
+ def sendFailure(e: Throwable): Unit
+
+ /**
+ * The sender of this message.
+ */
+ def sender: RpcEndpointRef
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
new file mode 100644
index 0000000000..769d59b7b3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.akka
+
+import java.net.URI
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+
+import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address}
+import akka.pattern.{ask => akkaAsk}
+import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
+import org.apache.spark.{SparkException, Logging, SparkConf}
+import org.apache.spark.rpc._
+import org.apache.spark.util.{ActorLogReceive, AkkaUtils}
+
+/**
+ * An [[RpcEnv]] implementation based on Akka.
+ *
+ * TODO Once we remove all usages of Akka in other places, we can move this file to a new project
+ * and remove Akka from the dependencies.
+ *
+ * @param actorSystem the underlying Akka ActorSystem
+ * @param conf the Spark configuration
+ * @param boundPort the port that the ActorSystem is bound to
+ */
+private[spark] class AkkaRpcEnv private[akka] (
+ val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
+ extends RpcEnv(conf) with Logging {
+
+ private val defaultAddress: RpcAddress = {
+ val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress
+    // In some test cases, the ActorSystem doesn't bind to any address,
+    // so just fall back to default values; this only happens in unit tests.
+ RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort))
+ }
+
+ override val address: RpcAddress = defaultAddress
+
+ /**
+ * A lookup table to find the [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make
+ * [[RpcEndpoint.self]] work.
+ */
+ private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]()
+
+ /**
+ * This map is needed to remove an `RpcEndpoint` from `endpointToRef` via its `RpcEndpointRef`.
+ */
+ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]()
+
+ private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = {
+ endpointToRef.put(endpoint, endpointRef)
+ refToEndpoint.put(endpointRef, endpoint)
+ }
+
+ private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = {
+ val endpoint = refToEndpoint.remove(endpointRef)
+ if (endpoint != null) {
+ endpointToRef.remove(endpoint)
+ }
+ }
+
+ /**
+ * Retrieve the [[RpcEndpointRef]] of `endpoint`.
+ */
+ override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
+ val endpointRef = endpointToRef.get(endpoint)
+ require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}")
+ endpointRef
+ }
+
+ override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
+ setupThreadSafeEndpoint(name, endpoint)
+ }
+
+ override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
+ @volatile var endpointRef: AkkaRpcEndpointRef = null
+ // Use lazy because the Actor needs to use `endpointRef`.
+ // So `actorRef` should be created after assigning `endpointRef`.
+ lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging {
+
+ assert(endpointRef != null)
+
+ override def preStart(): Unit = {
+ // Listen for remote client network events
+ context.system.eventStream.subscribe(self, classOf[AssociationEvent])
+ safelyCall(endpoint) {
+ endpoint.onStart()
+ }
+ }
+
+ override def receiveWithLogging: Receive = {
+ case AssociatedEvent(_, remoteAddress, _) =>
+ safelyCall(endpoint) {
+ endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress))
+ }
+
+ case DisassociatedEvent(_, remoteAddress, _) =>
+ safelyCall(endpoint) {
+ endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress))
+ }
+
+ case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) =>
+ safelyCall(endpoint) {
+ endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress))
+ }
+
+ case e: AssociationEvent =>
+ // TODO ignore?
+
+ case m: AkkaMessage =>
+ logDebug(s"Received RPC message: $m")
+ safelyCall(endpoint) {
+ processMessage(endpoint, m, sender)
+ }
+
+ case AkkaFailure(e) =>
+ safelyCall(endpoint) {
+ throw e
+ }
+
+ case message: Any => {
+ logWarning(s"Unknown message: $message")
+ }
+
+ }
+
+ override def postStop(): Unit = {
+ unregisterEndpoint(endpoint.self)
+ safelyCall(endpoint) {
+ endpoint.onStop()
+ }
+ }
+
+ }), name = name)
+ endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false)
+ registerEndpoint(endpoint, endpointRef)
+ // Now actorRef can be created safely
+ endpointRef.init()
+ endpointRef
+ }
+
+ private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = {
+ val message = m.message
+ val needReply = m.needReply
+ val pf: PartialFunction[Any, Unit] =
+ if (needReply) {
+ endpoint.receiveAndReply(new RpcCallContext {
+ override def sendFailure(e: Throwable): Unit = {
+ _sender ! AkkaFailure(e)
+ }
+
+ override def reply(response: Any): Unit = {
+ _sender ! AkkaMessage(response, false)
+ }
+
+ // Some RpcEndpoints need to know the sender's address
+ override val sender: RpcEndpointRef =
+ new AkkaRpcEndpointRef(defaultAddress, _sender, conf)
+ })
+ } else {
+ endpoint.receive
+ }
+ try {
+ pf.applyOrElse[Any, Unit](message, { message =>
+ throw new SparkException(s"Unmatched message $message from ${_sender}")
+ })
+ } catch {
+ case NonFatal(e) =>
+ if (needReply) {
+ // If the sender asks a reply, we should send the error back to the sender
+ _sender ! AkkaFailure(e)
+ } else {
+ throw e
+ }
+ }
+ }
+
+ /**
+ * Run `action` safely to avoid crashing the thread. If any non-fatal exception happens, it will
+ * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it.
+ */
+ private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
+ try {
+ action
+ } catch {
+ case NonFatal(e) => {
+ try {
+ endpoint.onError(e)
+ } catch {
+ case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e)
+ }
+ }
+ }
+ }
+
+ private def akkaAddressToRpcAddress(address: Address): RpcAddress = {
+ RpcAddress(address.host.getOrElse(defaultAddress.host),
+ address.port.getOrElse(defaultAddress.port))
+ }
+
+ override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
+ import actorSystem.dispatcher
+ actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout).
+ map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
+ }
+
+ override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
+ AkkaUtils.address(
+ AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName)
+ }
+
+ override def shutdown(): Unit = {
+ actorSystem.shutdown()
+ }
+
+ override def stop(endpoint: RpcEndpointRef): Unit = {
+ require(endpoint.isInstanceOf[AkkaRpcEndpointRef])
+ actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef)
+ }
+
+ override def awaitTermination(): Unit = {
+ actorSystem.awaitTermination()
+ }
+
+ override def toString: String = s"${getClass.getSimpleName}($actorSystem)"
+}
+
+private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
+
+ def create(config: RpcEnvConfig): RpcEnv = {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
+ config.name, config.host, config.port, config.conf, config.securityManager)
+ new AkkaRpcEnv(actorSystem, config.conf, boundPort)
+ }
+}
+
+private[akka] class AkkaRpcEndpointRef(
+ @transient defaultAddress: RpcAddress,
+ @transient _actorRef: => ActorRef,
+ @transient conf: SparkConf,
+ @transient initInConstructor: Boolean = true)
+ extends RpcEndpointRef(conf) with Logging {
+
+ lazy val actorRef = _actorRef
+
+ override lazy val address: RpcAddress = {
+ val akkaAddress = actorRef.path.address
+ RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host),
+ akkaAddress.port.getOrElse(defaultAddress.port))
+ }
+
+ override lazy val name: String = actorRef.path.name
+
+ private[akka] def init(): Unit = {
+ // Initialize the lazy vals
+ actorRef
+ address
+ name
+ }
+
+ if (initInConstructor) {
+ init()
+ }
+
+ override def send(message: Any): Unit = {
+ actorRef ! AkkaMessage(message, false)
+ }
+
+ override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
+ import scala.concurrent.ExecutionContext.Implicits.global
+ actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
+ case msg @ AkkaMessage(message, reply) =>
+ if (reply) {
+ logError(s"Receive $msg but the sender cannot reply")
+ Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
+ } else {
+ Future.successful(message)
+ }
+ case AkkaFailure(e) =>
+ Future.failed(e)
+ }.mapTo[T]
+ }
+
+ override def toString: String = s"${getClass.getSimpleName}($actorRef)"
+
+}
+
+/**
+ * A wrapper around `message` so that the receiver knows whether the sender expects a reply.
+ * @param message the message payload
+ * @param needReply if the sender expects a reply message
+ */
+private[akka] case class AkkaMessage(message: Any, needReply: Boolean)
+
+/**
+ * A reply carrying a failure from the receiver back to the sender.
+ */
+private[akka] case class AkkaFailure(e: Throwable)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
index a3caa9f000..f748f394d1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -19,10 +19,8 @@ package org.apache.spark.scheduler
import scala.collection.mutable
-import akka.actor.{ActorRef, Actor}
-
import org.apache.spark._
-import org.apache.spark.util.{AkkaUtils, ActorLogReceive}
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint}
private sealed trait OutputCommitCoordinationMessage extends Serializable
@@ -34,8 +32,8 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem
* policy.
*
* OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is
- * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit
- * output will be forwarded to the driver's OutputCommitCoordinator.
+ * configured with a reference to the driver's OutputCommitCoordinatorEndpoint, so requests to
+ * commit output will be forwarded to the driver's OutputCommitCoordinator.
*
* This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests)
* for an extensive design discussion.
@@ -43,10 +41,7 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem
private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
// Initialized by SparkEnv
- var coordinatorActor: Option[ActorRef] = None
- private val timeout = AkkaUtils.askTimeout(conf)
- private val maxAttempts = AkkaUtils.numRetries(conf)
- private val retryInterval = AkkaUtils.retryWaitMs(conf)
+ var coordinatorRef: Option[RpcEndpointRef] = None
private type StageId = Int
private type PartitionId = Long
@@ -81,9 +76,9 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
partition: PartitionId,
attempt: TaskAttemptId): Boolean = {
val msg = AskPermissionToCommitOutput(stage, partition, attempt)
- coordinatorActor match {
- case Some(actor) =>
- AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout)
+ coordinatorRef match {
+ case Some(endpointRef) =>
+ endpointRef.askWithReply[Boolean](msg)
case None =>
logError(
"canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?")
@@ -125,8 +120,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
}
def stop(): Unit = synchronized {
- coordinatorActor.foreach(_ ! StopCoordinator)
- coordinatorActor = None
+ coordinatorRef.foreach(_ send StopCoordinator)
+ coordinatorRef = None
authorizedCommittersByStage.clear()
}
@@ -157,16 +152,18 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
private[spark] object OutputCommitCoordinator {
// This actor is used only for RPC
- class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator)
- extends Actor with ActorLogReceive with Logging {
+ private[spark] class OutputCommitCoordinatorEndpoint(
+ override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator)
+ extends RpcEndpoint with Logging {
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case AskPermissionToCommitOutput(stage, partition, taskAttempt) =>
- sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)
+ context.reply(
+ outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt))
case StopCoordinator =>
logInfo("OutputCommitCoordinator stopped!")
- context.stop(self)
- sender ! true
+ context.reply(true)
+ stop()
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 48a6ede05e..6c2c526130 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -179,7 +179,7 @@ private[spark] object AkkaUtils extends Logging {
message: Any,
actor: ActorRef,
maxAttempts: Int,
- retryInterval: Int,
+ retryInterval: Long,
timeout: FiniteDuration): T = {
// TODO: Consider removing multiple attempts
if (actor == null) {
diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
new file mode 100644
index 0000000000..6665b17c3d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.{SparkEnv, SparkConf}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
+
+object RpcUtils {
+
+ /**
+ * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name.
+ */
+ def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = {
+ val driverActorSystemName = SparkEnv.driverActorSystemName
+ val driverHost: String = conf.get("spark.driver.host", "localhost")
+ val driverPort: Int = conf.getInt("spark.driver.port", 7077)
+ Utils.checkHost(driverHost, "Expected hostname")
+ rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
index 5e538d6fab..6a6f29dd61 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
@@ -17,32 +17,38 @@
package org.apache.spark.deploy.worker
-import akka.actor.{ActorSystem, AddressFromURIString, Props}
-import akka.testkit.TestActorRef
-import akka.remote.DisassociatedEvent
+import akka.actor.AddressFromURIString
+import org.apache.spark.SparkConf
+import org.apache.spark.SecurityManager
+import org.apache.spark.rpc.{RpcAddress, RpcEnv}
import org.scalatest.FunSuite
class WorkerWatcherSuite extends FunSuite {
test("WorkerWatcher shuts down on valid disassociation") {
- val actorSystem = ActorSystem("test")
- val targetWorkerUrl = "akka://1.2.3.4/user/Worker"
+ val conf = new SparkConf()
+ val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
+ val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker"
val targetWorkerAddress = AddressFromURIString(targetWorkerUrl)
- val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem)
- val workerWatcher = actorRef.underlyingActor
+ val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
workerWatcher.setTesting(testing = true)
- actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false))
- assert(actorRef.underlyingActor.isShutDown)
+ rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
+ workerWatcher.onDisconnected(
+ RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get))
+ assert(workerWatcher.isShutDown)
+ rpcEnv.shutdown()
}
test("WorkerWatcher stays alive on invalid disassociation") {
- val actorSystem = ActorSystem("test")
- val targetWorkerUrl = "akka://1.2.3.4/user/Worker"
- val otherAkkaURL = "akka://4.3.2.1/user/OtherActor"
+ val conf = new SparkConf()
+ val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
+ val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker"
+ val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor"
val otherAkkaAddress = AddressFromURIString(otherAkkaURL)
- val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem)
- val workerWatcher = actorRef.underlyingActor
+ val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
workerWatcher.setTesting(testing = true)
- actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false))
- assert(!actorRef.underlyingActor.isShutDown)
+ rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
+ workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get))
+ assert(!workerWatcher.isShutDown)
+ rpcEnv.shutdown()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
new file mode 100644
index 0000000000..e07bdb9637
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -0,0 +1,525 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException}
+
+import scala.collection.mutable
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.{SparkException, SparkConf}
+
+/**
+ * Common tests for an RpcEnv implementation.
+ */
+abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
+
+ var env: RpcEnv = _
+
+ override def beforeAll(): Unit = {
+ val conf = new SparkConf()
+ env = createRpcEnv(conf, "local", 12345)
+ }
+
+ override def afterAll(): Unit = {
+ if(env != null) {
+ env.shutdown()
+ }
+ }
+
+ def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv
+
+ test("send a message locally") {
+ @volatile var message: String = null
+ val rpcEndpointRef = env.setupEndpoint("send-locally", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case msg: String => message = msg
+ }
+ })
+ rpcEndpointRef.send("hello")
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ assert("hello" === message)
+ }
+ }
+
+ test("send a message remotely") {
+ @volatile var message: String = null
+ // Set up a RpcEndpoint using env
+ env.setupEndpoint("send-remotely", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case msg: String => message = msg
+ }
+ })
+
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote" ,13345)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely")
+ try {
+ rpcEndpointRef.send("hello")
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ assert("hello" === message)
+ }
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("send a RpcEndpointRef") {
+ val endpoint = new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case "Hello" => context.reply(self)
+ case "Echo" => context.reply("Echo")
+ }
+ }
+ val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint)
+
+ val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello")
+ val reply = newRpcEndpointRef.askWithReply[String]("Echo")
+ assert("Echo" === reply)
+ }
+
+ test("ask a message locally") {
+ val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case msg: String => {
+ context.reply(msg)
+ }
+ }
+ })
+ val reply = rpcEndpointRef.askWithReply[String]("hello")
+ assert("hello" === reply)
+ }
+
+ test("ask a message remotely") {
+ env.setupEndpoint("ask-remotely", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case msg: String => {
+ context.reply(msg)
+ }
+ }
+ })
+
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely")
+ try {
+ val reply = rpcEndpointRef.askWithReply[String]("hello")
+ assert("hello" === reply)
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("ask a message timeout") {
+ env.setupEndpoint("ask-timeout", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case msg: String => {
+ Thread.sleep(100)
+ context.reply(msg)
+ }
+ }
+ })
+
+ val conf = new SparkConf()
+ conf.set("spark.akka.retry.wait", "0")
+ conf.set("spark.akka.num.retries", "1")
+ val anotherEnv = createRpcEnv(conf, "remote", 13345)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout")
+ try {
+ val e = intercept[Exception] {
+ rpcEndpointRef.askWithReply[String]("hello", 1 millis)
+ }
+ assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException])
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("onStart and onStop") {
+ val stopLatch = new CountDownLatch(1)
+ val calledMethods = mutable.ArrayBuffer[String]()
+
+ val endpoint = new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def onStart(): Unit = {
+ calledMethods += "start"
+ }
+
+ override def receive = {
+ case msg: String =>
+ }
+
+ override def onStop(): Unit = {
+ calledMethods += "stop"
+ stopLatch.countDown()
+ }
+ }
+ val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint)
+ env.stop(rpcEndpointRef)
+ stopLatch.await(10, TimeUnit.SECONDS)
+ assert(List("start", "stop") === calledMethods)
+ }
+
+ test("onError: error in onStart") {
+ @volatile var e: Throwable = null
+ env.setupEndpoint("onError-onStart", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def onStart(): Unit = {
+ throw new RuntimeException("Oops!")
+ }
+
+ override def receive = {
+ case m =>
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ e = cause
+ }
+ })
+
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ assert(e.getMessage === "Oops!")
+ }
+ }
+
+ test("onError: error in onStop") {
+ @volatile var e: Throwable = null
+ val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case m =>
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ e = cause
+ }
+
+ override def onStop(): Unit = {
+ throw new RuntimeException("Oops!")
+ }
+ })
+
+ env.stop(endpointRef)
+
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ assert(e.getMessage === "Oops!")
+ }
+ }
+
+ test("onError: error in receive") {
+ @volatile var e: Throwable = null
+ val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case m => throw new RuntimeException("Oops!")
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ e = cause
+ }
+ })
+
+ endpointRef.send("Foo")
+
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ assert(e.getMessage === "Oops!")
+ }
+ }
+
+ test("self: call in onStart") {
+ @volatile var callSelfSuccessfully = false
+
+ env.setupEndpoint("self-onStart", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def onStart(): Unit = {
+ self
+ callSelfSuccessfully = true
+ }
+
+ override def receive = {
+ case m =>
+ }
+ })
+
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ // Calling `self` in `onStart` is fine
+ assert(callSelfSuccessfully === true)
+ }
+ }
+
+ test("self: call in receive") {
+ @volatile var callSelfSuccessfully = false
+
+ val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case m => {
+ self
+ callSelfSuccessfully = true
+ }
+ }
+ })
+
+ endpointRef.send("Foo")
+
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ // Calling `self` in `receive` is fine
+ assert(callSelfSuccessfully === true)
+ }
+ }
+
+ test("self: call in onStop") {
+ @volatile var e: Throwable = null
+
+ val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case m =>
+ }
+
+ override def onStop(): Unit = {
+ self
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ e = cause
+ }
+ })
+
+ env.stop(endpointRef)
+
+ eventually(timeout(5 seconds), interval(10 millis)) {
+ // Calling `self` in `onStop` is invalid
+ assert(e != null)
+ assert(e.getMessage.contains("Cannot find RpcEndpointRef"))
+ }
+ }
+
+ test("call receive in sequence") {
+ // If an RpcEnv implementation breaks the `receive` contract, this test should expose it
+ for (i <- 0 until 100) {
+ @volatile var result = 0
+ val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case m => result += 1
+ }
+
+ })
+
+ (0 until 10) foreach { _ =>
+ new Thread {
+ override def run(): Unit = {
+ (0 until 100) foreach { _ =>
+ endpointRef.send("Hello")
+ }
+ }
+ }.start()
+ }
+
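+ // 10 threads x 100 messages each should yield exactly 1000 increments if `receive` is called sequentially.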
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(result == 1000)
+ }
+
+ env.stop(endpointRef)
+ }
+ }
+
+ test("stop(RpcEndpointRef) reentrant") {
+ @volatile var onStopCount = 0
+ val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case m =>
+ }
+
+ override def onStop(): Unit = {
+ onStopCount += 1
+ }
+ })
+
+ env.stop(endpointRef)
+ env.stop(endpointRef)
+
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ // Calling stop twice should only trigger onStop once.
+ assert(onStopCount == 1)
+ }
+ }
+
+ test("sendWithReply") {
+ val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case m => context.reply("ack")
+ }
+ })
+
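+ // sendWithReply returns a Future that should complete with the endpoint's reply.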
+ val f = endpointRef.sendWithReply[String]("Hi")
+ val ack = Await.result(f, 5 seconds)
+ assert("ack" === ack)
+
+ env.stop(endpointRef)
+ }
+
+ test("sendWithReply: remotely") {
+ env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case m => context.reply("ack")
+ }
+ })
+
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ // Use anotherEnv to look up the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely")
+ try {
+ val f = rpcEndpointRef.sendWithReply[String]("hello")
+ val ack = Await.result(f, 5 seconds)
+ assert("ack" === ack)
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("sendWithReply: error") {
+ val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case m => context.sendFailure(new SparkException("Oops"))
+ }
+ })
+
+ val f = endpointRef.sendWithReply[String]("Hi")
+ val e = intercept[SparkException] {
+ Await.result(f, 5 seconds)
+ }
+ assert("Oops" === e.getMessage)
+
+ env.stop(endpointRef)
+ }
+
+ test("sendWithReply: remotely error") {
+ env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext) = {
+ case msg: String => context.sendFailure(new SparkException("Oops"))
+ }
+ })
+
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ // Use anotherEnv to look up the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef(
+ "local", env.address, "sendWithReply-remotely-error")
+ try {
+ val f = rpcEndpointRef.sendWithReply[String]("hello")
+ val e = intercept[SparkException] {
+ Await.result(f, 5 seconds)
+ }
+ assert("Oops" === e.getMessage)
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+ test("network events") {
+ val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
+ env.setupThreadSafeEndpoint("network-events", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case "hello" =>
+ case m => events += "receive" -> m
+ }
+
+ override def onConnected(remoteAddress: RpcAddress): Unit = {
+ events += "onConnected" -> remoteAddress
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ events += "onDisconnected" -> remoteAddress
+ }
+
+ override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ events += "onNetworkError" -> remoteAddress
+ }
+
+ })
+
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ // Use anotherEnv to look up the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef(
+ "local", env.address, "network-events")
+ val remoteAddress = anotherEnv.address
+ rpcEndpointRef.send("hello")
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events === List(("onConnected", remoteAddress)))
+ }
+
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
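+ // Shutting down the remote RpcEnv should be observed as a network error followed by a disconnection.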
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(events === List(
+ ("onConnected", remoteAddress),
+ ("onNetworkError", remoteAddress),
+ ("onDisconnected", remoteAddress)))
+ }
+ }
+}
+
+case object Start
+
+case class Ping(id: Int)
+
+case class Pong(id: Int)
diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
new file mode 100644
index 0000000000..58214c0637
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.akka
+
+import org.apache.spark.rpc._
+import org.apache.spark.{SecurityManager, SparkConf}
+
+class AkkaRpcEnvSuite extends RpcEnvSuite {
+
+ override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
+ new AkkaRpcEnvFactory().create(
+ RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf)))
+ }
+
+ test("setupEndpointRef: systemName, address, endpointName") {
+ val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receive = {
+ case _ =>
+ }
+ })
+ val conf = new SparkConf()
+ val newRpcEnv = new AkkaRpcEnvFactory().create(
+ RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf)))
+ try {
+ val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint")
+ assert("akka.tcp://local@localhost:12345/user/test_endpoint" ===
+ newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString)
+ } finally {
+ newRpcEnv.shutdown()
+ }
+ }
+
+}