path: root/core
author      zsxwing <zsxwing@gmail.com>    2015-06-30 17:39:55 -0700
committer   Reynold Xin <rxin@databricks.com>    2015-06-30 17:39:55 -0700
commit      3bee0f1466ddd69f26e95297b5e0d2398b6c6268 (patch)
tree        83309df1a139a49f1821e54344d551581b210cf8 /core
parent      ccdb05222a223187199183fd48e3a3313d536965 (diff)
download    spark-3bee0f1466ddd69f26e95297b5e0d2398b6c6268.tar.gz
            spark-3bee0f1466ddd69f26e95297b5e0d2398b6c6268.tar.bz2
            spark-3bee0f1466ddd69f26e95297b5e0d2398b6c6268.zip
[SPARK-6602][Core] Update Master, Worker, Client, AppClient and related classes to use RpcEndpoint
This PR updates the remaining Actors in core to RpcEndpoint. Because there is no `ActorSelection` in RpcEnv, I changed the logic of `registerWithMaster` in Worker and AppClient to avoid blocking the message loop. These changes need to be reviewed carefully.

Author: zsxwing <zsxwing@gmail.com>

Closes #5392 from zsxwing/rpc-rewrite-part3 and squashes the following commits:

2de7bed [zsxwing] Merge branch 'master' into rpc-rewrite-part3
f12d943 [zsxwing] Address comments
9137b82 [zsxwing] Fix the code style
e734c71 [zsxwing] Merge branch 'master' into rpc-rewrite-part3
2d24fb5 [zsxwing] Fix the code style
5a82374 [zsxwing] Merge branch 'master' into rpc-rewrite-part3
fa47110 [zsxwing] Merge branch 'master' into rpc-rewrite-part3
72304f0 [zsxwing] Update the error strategy for AkkaRpcEnv
e56cb16 [zsxwing] Always send failure back to the sender
a7b86e6 [zsxwing] Use JFuture for java.util.concurrent.Future
aa34b9b [zsxwing] Fix the code style
bd541e7 [zsxwing] Merge branch 'master' into rpc-rewrite-part3
25a84d8 [zsxwing] Use ThreadUtils
060ff31 [zsxwing] Merge branch 'master' into rpc-rewrite-part3
dbfc916 [zsxwing] Improve the docs and comments
837927e [zsxwing] Merge branch 'master' into rpc-rewrite-part3
5c27f97 [zsxwing] Merge branch 'master' into rpc-rewrite-part3
fadbb9e [zsxwing] Fix the code style
6637e3c [zsxwing] Merge remote-tracking branch 'origin/master' into rpc-rewrite-part3
7fdee0e [zsxwing] Fix the return type to ExecutorService and ScheduledExecutorService
e8ad0a5 [zsxwing] Fix the code style
6b2a104 [zsxwing] Log error and use SparkExitCode.UNCAUGHT_EXCEPTION exit code
fbf3194 [zsxwing] Add Utils.newDaemonSingleThreadExecutor and newDaemonSingleThreadScheduledExecutor
b776817 [zsxwing] Update Master, Worker, Client, AppClient and related classes to use RpcEndpoint
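At a high level, each converted class swaps Akka's Actor callbacks for the RpcEndpoint contract: preStart/postStop become onStart/onStop, receiveWithLogging splits into receive (one-way sends) and receiveAndReply (ask pattern with an explicit RpcCallContext), and remoting lifecycle events become onDisconnected/onNetworkError. The following is only a rough sketch of that contract, not code from this patch: the package, class, and Ping/Pong messages are made up, and it assumes compilation inside the org.apache.spark namespace because the rpc API is private[spark].

    package org.apache.spark.example // hypothetical package so the private[spark] rpc API is visible

    import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

    // Illustrative messages only; not part of this patch.
    case object Ping
    case object Pong

    private[spark] class PingEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

      override def onStart(): Unit = {
        // setup that used to live in Actor.preStart
      }

      // One-way messages (Akka's `!`/`tell`) arrive here; there is no implicit sender.
      override def receive: PartialFunction[Any, Unit] = {
        case Ping => // handle and, if needed, reply via a ref carried inside the message
      }

      // Ask-style messages arrive here; `context.reply` replaces `sender ! ...`.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case Ping => context.reply(Pong)
      }

      override def onStop(): Unit = {
        // cleanup that used to live in Actor.postStop
      }
    }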
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/Client.scala | 156
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala | 199
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 392
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala | 8
-rwxr-xr-x  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 318
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala | 56
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala | 54
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala | 15
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala | 55
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala | 20
27 files changed, 806 insertions, 633 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 848b62f9de..71f7e21291 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -18,17 +18,17 @@
package org.apache.spark.deploy
import scala.collection.mutable.HashSet
-import scala.concurrent._
+import scala.concurrent.ExecutionContext
+import scala.reflect.ClassTag
+import scala.util.{Failure, Success}
-import akka.actor._
-import akka.pattern.ask
-import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.log4j.{Level, Logger}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils}
+import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils}
/**
* Proxy that relays messages to the driver.
@@ -36,20 +36,30 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils}
* We currently don't support retry if submission fails. In HA mode, client will submit request to
* all masters and see which one could handle it.
*/
-private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
- extends Actor with ActorLogReceive with Logging {
-
- private val masterActors = driverArgs.masters.map { m =>
- context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system)))
- }
- private val lostMasters = new HashSet[Address]
- private var activeMasterActor: ActorSelection = null
-
- val timeout = RpcUtils.askTimeout(conf)
-
- override def preStart(): Unit = {
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
-
+private class ClientEndpoint(
+ override val rpcEnv: RpcEnv,
+ driverArgs: ClientArguments,
+ masterEndpoints: Seq[RpcEndpointRef],
+ conf: SparkConf)
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ // A scheduled executor used to send messages at the specified time.
+ private val forwardMessageThread =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message")
+ // Used to provide the implicit parameter of `Future` methods.
+ private val forwardMessageExecutionContext =
+ ExecutionContext.fromExecutor(forwardMessageThread,
+ t => t match {
+ case ie: InterruptedException => // Exit normally
+ case e: Throwable =>
+ logError(e.getMessage, e)
+ System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
+ })
+
+ private val lostMasters = new HashSet[RpcAddress]
+ private var activeMasterEndpoint: RpcEndpointRef = null
+
+ override def onStart(): Unit = {
driverArgs.cmd match {
case "launch" =>
// TODO: We could add an env variable here and intercept it in `sc.addJar` that would
@@ -82,29 +92,37 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
driverArgs.cores,
driverArgs.supervise,
command)
-
- // This assumes only one Master is active at a time
- for (masterActor <- masterActors) {
- masterActor ! RequestSubmitDriver(driverDescription)
- }
+ ayncSendToMasterAndForwardReply[SubmitDriverResponse](
+ RequestSubmitDriver(driverDescription))
case "kill" =>
val driverId = driverArgs.driverId
- // This assumes only one Master is active at a time
- for (masterActor <- masterActors) {
- masterActor ! RequestKillDriver(driverId)
- }
+ ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId))
+ }
+ }
+
+ /**
+ * Send the message to master and forward the reply to self asynchronously.
+ */
+ private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = {
+ for (masterEndpoint <- masterEndpoints) {
+ masterEndpoint.ask[T](message).onComplete {
+ case Success(v) => self.send(v)
+ case Failure(e) =>
+ logWarning(s"Error sending messages to master $masterEndpoint", e)
+ }(forwardMessageExecutionContext)
}
}
/* Find out driver status then exit the JVM */
def pollAndReportStatus(driverId: String) {
+ // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread
+ // is fine.
println("... waiting before polling master for driver state")
Thread.sleep(5000)
println("... polling master for driver state")
- val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout)
- .mapTo[DriverStatusResponse]
- val statusResponse = Await.result(statusFuture, timeout)
+ val statusResponse =
+ activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId))
statusResponse.found match {
case false =>
println(s"ERROR: Cluster master did not recognize $driverId")
@@ -127,50 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
}
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receive: PartialFunction[Any, Unit] = {
- case SubmitDriverResponse(success, driverId, message) =>
+ case SubmitDriverResponse(master, success, driverId, message) =>
println(message)
if (success) {
- activeMasterActor = context.actorSelection(sender.path)
+ activeMasterEndpoint = master
pollAndReportStatus(driverId.get)
} else if (!Utils.responseFromBackup(message)) {
System.exit(-1)
}
- case KillDriverResponse(driverId, success, message) =>
+ case KillDriverResponse(master, driverId, success, message) =>
println(message)
if (success) {
- activeMasterActor = context.actorSelection(sender.path)
+ activeMasterEndpoint = master
pollAndReportStatus(driverId)
} else if (!Utils.responseFromBackup(message)) {
System.exit(-1)
}
+ }
- case DisassociatedEvent(_, remoteAddress, _) =>
- if (!lostMasters.contains(remoteAddress)) {
- println(s"Error connecting to master $remoteAddress.")
- lostMasters += remoteAddress
- // Note that this heuristic does not account for the fact that a Master can recover within
- // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This
- // is not currently a concern, however, because this client does not retry submissions.
- if (lostMasters.size >= masterActors.size) {
- println("No master is available, exiting.")
- System.exit(-1)
- }
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (!lostMasters.contains(remoteAddress)) {
+ println(s"Error connecting to master $remoteAddress.")
+ lostMasters += remoteAddress
+ // Note that this heuristic does not account for the fact that a Master can recover within
+ // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This
+ // is not currently a concern, however, because this client does not retry submissions.
+ if (lostMasters.size >= masterEndpoints.size) {
+ println("No master is available, exiting.")
+ System.exit(-1)
}
+ }
+ }
- case AssociationErrorEvent(cause, _, remoteAddress, _, _) =>
- if (!lostMasters.contains(remoteAddress)) {
- println(s"Error connecting to master ($remoteAddress).")
- println(s"Cause was: $cause")
- lostMasters += remoteAddress
- if (lostMasters.size >= masterActors.size) {
- println("No master is available, exiting.")
- System.exit(-1)
- }
+ override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ if (!lostMasters.contains(remoteAddress)) {
+ println(s"Error connecting to master ($remoteAddress).")
+ println(s"Cause was: $cause")
+ lostMasters += remoteAddress
+ if (lostMasters.size >= masterEndpoints.size) {
+ println("No master is available, exiting.")
+ System.exit(-1)
}
+ }
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ println(s"Error processing messages, exiting.")
+ cause.printStackTrace()
+ System.exit(-1)
+ }
+
+ override def onStop(): Unit = {
+ forwardMessageThread.shutdownNow()
}
}
@@ -194,15 +224,13 @@ object Client {
conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING"))
Logger.getRootLogger.setLevel(driverArgs.logLevel)
- val (actorSystem, _) = AkkaUtils.createActorSystem(
- "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
+ val rpcEnv =
+ RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
- // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely
- for (m <- driverArgs.masters) {
- Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem))
- }
- actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
+ val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL).
+ map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME))
+ rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf))
- actorSystem.awaitTermination()
+ rpcEnv.awaitTermination()
}
}
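The new ayncSendToMasterAndForwardReply above is the key trick in this file: the ask runs as a Future, and its callback merely re-sends the reply to self, so the actual handling happens back inside receive on the endpoint's own thread instead of blocking the message loop. A condensed, hypothetical sketch of the same pattern follows; only RpcEndpointRef.ask, self.send, and ThreadUtils come from this patch, while the package, class, and messages are placeholders.

    package org.apache.spark.example

    import scala.concurrent.ExecutionContext
    import scala.reflect.ClassTag
    import scala.util.{Failure, Success}

    import org.apache.spark.Logging
    import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
    import org.apache.spark.util.ThreadUtils

    private[spark] class ForwardingEndpoint(
        override val rpcEnv: RpcEnv,
        masters: Seq[RpcEndpointRef])
      extends ThreadSafeRpcEndpoint with Logging {

      // Future callbacks run on this daemon thread, never on the RPC event loop.
      private implicit val forwardContext: ExecutionContext = ExecutionContext.fromExecutor(
        ThreadUtils.newDaemonSingleThreadScheduledExecutor("forward-message"))

      // Ask every master; whichever replies arrive are funneled back into receive().
      private def askAndForward[T: ClassTag](message: Any): Unit = {
        for (master <- masters) {
          master.ask[T](message).onComplete {
            case Success(reply) => self.send(reply)
            case Failure(e) => logWarning(s"Error asking master $master", e)
          }
        }
      }

      override def onStart(): Unit = askAndForward[Any]("status") // placeholder message

      override def receive: PartialFunction[Any, Unit] = {
        case reply => logInfo(s"Reply handled on the endpoint's own thread: $reply")
      }
    }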
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 9db6fd1ac4..12727de9b4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.RecoveryState.MasterState
import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
private[deploy] sealed trait DeployMessage extends Serializable
-/** Contains messages sent between Scheduler actor nodes. */
+/** Contains messages sent between Scheduler endpoint nodes. */
private[deploy] object DeployMessages {
// Worker to Master
@@ -37,6 +38,7 @@ private[deploy] object DeployMessages {
id: String,
host: String,
port: Int,
+ worker: RpcEndpointRef,
cores: Int,
memory: Int,
webUiPort: Int,
@@ -63,11 +65,11 @@ private[deploy] object DeployMessages {
case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription],
driverIds: Seq[String])
- case class Heartbeat(workerId: String) extends DeployMessage
+ case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage
// Master to Worker
- case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage
+ case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
@@ -92,13 +94,13 @@ private[deploy] object DeployMessages {
// Worker internal
- case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders
+ case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders
case object ReregisterWithMaster // used when a worker attempts to reconnect to a master
// AppClient to Master
- case class RegisterApplication(appDescription: ApplicationDescription)
+ case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef)
extends DeployMessage
case class UnregisterApplication(appId: String)
@@ -107,7 +109,7 @@ private[deploy] object DeployMessages {
// Master to AppClient
- case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage
+ case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage
// TODO(matei): replace hostPort with host
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
@@ -123,12 +125,14 @@ private[deploy] object DeployMessages {
case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage
- case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String)
+ case class SubmitDriverResponse(
+ master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String)
extends DeployMessage
case class RequestKillDriver(driverId: String) extends DeployMessage
- case class KillDriverResponse(driverId: String, success: Boolean, message: String)
+ case class KillDriverResponse(
+ master: RpcEndpointRef, driverId: String, success: Boolean, message: String)
extends DeployMessage
case class RequestDriverStatus(driverId: String) extends DeployMessage
@@ -142,7 +146,7 @@ private[deploy] object DeployMessages {
// Master to Worker & AppClient
- case class MasterChanged(masterUrl: String, masterWebUiUrl: String)
+ case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String)
// MasterWebUI To Master
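The common thread in these message changes: RpcEnv's one-way send has no implicit sender, so messages such as Heartbeat, RegisterApplication, and the *Response classes now carry the sender's RpcEndpointRef and the receiver answers through that ref. A minimal, hypothetical illustration (Beat and Reconnect are placeholders, not real DeployMessages):

    package org.apache.spark.example

    import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}

    // Placeholder analogues of Heartbeat(workerId, worker) and ReconnectWorker(masterUrl).
    case class Beat(workerId: String, worker: RpcEndpointRef)
    case class Reconnect(masterUrl: String)

    private[spark] class MiniMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

      override def receive: PartialFunction[Any, Unit] = {
        case Beat(id, workerRef) =>
          // Reply through the ref shipped in the message; Akka's `sender` is gone.
          workerRef.send(Reconnect("spark://master:7077"))
      }
    }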
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 0550f00a17..53356addf6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -19,8 +19,7 @@ package org.apache.spark.deploy
import scala.collection.mutable.ArrayBuffer
-import akka.actor.ActorSystem
-
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.deploy.master.Master
@@ -41,8 +40,8 @@ class LocalSparkCluster(
extends Logging {
private val localHostname = Utils.localHostName()
- private val masterActorSystems = ArrayBuffer[ActorSystem]()
- private val workerActorSystems = ArrayBuffer[ActorSystem]()
+ private val masterRpcEnvs = ArrayBuffer[RpcEnv]()
+ private val workerRpcEnvs = ArrayBuffer[RpcEnv]()
// exposed for testing
var masterWebUIPort = -1
@@ -55,18 +54,17 @@ class LocalSparkCluster(
.set("spark.shuffle.service.enabled", "false")
/* Start the Master */
- val (masterSystem, masterPort, webUiPort, _) =
- Master.startSystemAndActor(localHostname, 0, 0, _conf)
+ val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf)
masterWebUIPort = webUiPort
- masterActorSystems += masterSystem
- val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort
+ masterRpcEnvs += rpcEnv
+ val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port
val masters = Array(masterUrl)
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
- val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
+ val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masters, null, Some(workerNum), _conf)
- workerActorSystems += workerSystem
+ workerRpcEnvs += workerEnv
}
masters
@@ -77,11 +75,11 @@ class LocalSparkCluster(
// Stop the workers before the master so they don't get upset that it disconnected
// TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors!
// This is unfortunate, but for now we just comment it out.
- workerActorSystems.foreach(_.shutdown())
+ workerRpcEnvs.foreach(_.shutdown())
// workerActorSystems.foreach(_.awaitTermination())
- masterActorSystems.foreach(_.shutdown())
+ masterRpcEnvs.foreach(_.shutdown())
// masterActorSystems.foreach(_.awaitTermination())
- masterActorSystems.clear()
- workerActorSystems.clear()
+ masterRpcEnvs.clear()
+ workerRpcEnvs.clear()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 43c8a934c3..79b251e7e6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -17,20 +17,17 @@
package org.apache.spark.deploy.client
-import java.util.concurrent.TimeoutException
+import java.util.concurrent._
+import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
-import scala.concurrent.Await
-import scala.concurrent.duration._
-
-import akka.actor._
-import akka.pattern.ask
-import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
+import scala.util.control.NonFatal
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
-import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils}
+import org.apache.spark.rpc._
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL,
@@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils}
* @param masterUrls Each url should look like spark://host:port.
*/
private[spark] class AppClient(
- actorSystem: ActorSystem,
+ rpcEnv: RpcEnv,
masterUrls: Array[String],
appDescription: ApplicationDescription,
listener: AppClientListener,
conf: SparkConf)
extends Logging {
- private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
+ private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
- private val REGISTRATION_TIMEOUT = 20.seconds
+ private val REGISTRATION_TIMEOUT_SECONDS = 20
private val REGISTRATION_RETRIES = 3
- private var masterAddress: Address = null
- private var actor: ActorRef = null
+ private var endpoint: RpcEndpointRef = null
private var appId: String = null
- private var registered = false
- private var activeMasterUrl: String = null
+ @volatile private var registered = false
+
+ private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint
+ with Logging {
+
+ private var master: Option[RpcEndpointRef] = None
+ // To avoid calling listener.disconnected() multiple times
+ private var alreadyDisconnected = false
+ @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times
+ @volatile private var registerMasterFutures: Array[JFuture[_]] = null
+ @volatile private var registrationRetryTimer: JScheduledFuture[_] = null
+
+ // A thread pool for registering with masters. Because registering with a master is a blocking
+ // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
+ // time so that we can register with all masters.
+ private val registerMasterThreadPool = new ThreadPoolExecutor(
+ 0,
+ masterRpcAddresses.size, // Make sure we can register with all masters at the same time
+ 60L, TimeUnit.SECONDS,
+ new SynchronousQueue[Runnable](),
+ ThreadUtils.namedThreadFactory("appclient-register-master-threadpool"))
- private class ClientActor extends Actor with ActorLogReceive with Logging {
- var master: ActorSelection = null
- var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
- var alreadyDead = false // To avoid calling listener.dead() multiple times
- var registrationRetryTimer: Option[Cancellable] = None
+ // A scheduled executor for scheduling the registration actions
+ private val registrationRetryThread =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread")
- override def preStart() {
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ override def onStart(): Unit = {
try {
- registerWithMaster()
+ registerWithMaster(1)
} catch {
case e: Exception =>
logWarning("Failed to connect to master", e)
markDisconnected()
- context.stop(self)
+ stop()
}
}
- def tryRegisterAllMasters() {
- for (masterAkkaUrl <- masterAkkaUrls) {
- logInfo("Connecting to master " + masterAkkaUrl + "...")
- val actor = context.actorSelection(masterAkkaUrl)
- actor ! RegisterApplication(appDescription)
+ /**
+ * Register with all masters asynchronously and returns an array `Future`s for cancellation.
+ */
+ private def tryRegisterAllMasters(): Array[JFuture[_]] = {
+ for (masterAddress <- masterRpcAddresses) yield {
+ registerMasterThreadPool.submit(new Runnable {
+ override def run(): Unit = try {
+ if (registered) {
+ return
+ }
+ logInfo("Connecting to master " + masterAddress.toSparkURL + "...")
+ val masterRef =
+ rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
+ masterRef.send(RegisterApplication(appDescription, self))
+ } catch {
+ case ie: InterruptedException => // Cancelled
+ case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
+ }
+ })
}
}
- def registerWithMaster() {
- tryRegisterAllMasters()
- import context.dispatcher
- var retries = 0
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
+ /**
+ * Register with all masters asynchronously. It will call `registerWithMaster` every
+ * REGISTRATION_TIMEOUT_SECONDS seconds until exceeding REGISTRATION_RETRIES times.
+ * Once we connect to a master successfully, all scheduling work and Futures will be cancelled.
+ *
+ * nthRetry means this is the nth attempt to register with master.
+ */
+ private def registerWithMaster(nthRetry: Int) {
+ registerMasterFutures = tryRegisterAllMasters()
+ registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = {
Utils.tryOrExit {
- retries += 1
if (registered) {
- registrationRetryTimer.foreach(_.cancel())
- } else if (retries >= REGISTRATION_RETRIES) {
+ registerMasterFutures.foreach(_.cancel(true))
+ registerMasterThreadPool.shutdownNow()
+ } else if (nthRetry >= REGISTRATION_RETRIES) {
markDead("All masters are unresponsive! Giving up.")
} else {
- tryRegisterAllMasters()
+ registerMasterFutures.foreach(_.cancel(true))
+ registerWithMaster(nthRetry + 1)
}
}
}
- }
+ }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
}
- def changeMaster(url: String) {
- // activeMasterUrl is a valid Spark url since we receive it from master.
- activeMasterUrl = url
- master = context.actorSelection(
- Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem)))
- masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem))
+ /**
+ * Send a message to the current master. If we have not yet registered successfully with any
+ * master, the message will be dropped.
+ */
+ private def sendToMaster(message: Any): Unit = {
+ master match {
+ case Some(masterRef) => masterRef.send(message)
+ case None => logWarning(s"Drop $message because has not yet connected to master")
+ }
}
- private def isPossibleMaster(remoteUrl: Address) = {
- masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort)
+ private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = {
+ masterRpcAddresses.contains(remoteAddress)
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RegisteredApplication(appId_, masterUrl) =>
+ override def receive: PartialFunction[Any, Unit] = {
+ case RegisteredApplication(appId_, masterRef) =>
+ // FIXME How to handle the following cases?
+ // 1. A master receives multiple registrations and sends back multiple
+ // RegisteredApplications due to an unstable network.
+ // 2. Receive multiple RegisteredApplication from different masters because the master is
+ // changing.
appId = appId_
registered = true
- changeMaster(masterUrl)
+ master = Some(masterRef)
listener.connected(appId)
case ApplicationRemoved(message) =>
markDead("Master removed our application: %s".format(message))
- context.stop(self)
+ stop()
case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort,
cores))
- master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)
+ // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not
+ // guaranteed), `ExecutorStateChanged` may be sent to a dead master.
+ sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None))
listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
@@ -142,24 +184,32 @@ private[spark] class AppClient(
listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
}
- case MasterChanged(masterUrl, masterWebUiUrl) =>
- logInfo("Master has changed, new master is at " + masterUrl)
- changeMaster(masterUrl)
+ case MasterChanged(masterRef, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL)
+ master = Some(masterRef)
alreadyDisconnected = false
- sender ! MasterChangeAcknowledged(appId)
+ masterRef.send(MasterChangeAcknowledged(appId))
+ }
- case DisassociatedEvent(_, address, _) if address == masterAddress =>
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case StopAppClient =>
+ markDead("Application has been stopped.")
+ sendToMaster(UnregisterApplication(appId))
+ context.reply(true)
+ stop()
+ }
+
+ override def onDisconnected(address: RpcAddress): Unit = {
+ if (master.exists(_.address == address)) {
logWarning(s"Connection to $address failed; waiting for master to reconnect...")
markDisconnected()
+ }
+ }
- case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) =>
+ override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = {
+ if (isPossibleMaster(address)) {
logWarning(s"Could not connect to $address: $cause")
-
- case StopAppClient =>
- markDead("Application has been stopped.")
- master ! UnregisterApplication(appId)
- sender ! true
- context.stop(self)
+ }
}
/**
@@ -179,28 +229,31 @@ private[spark] class AppClient(
}
}
- override def postStop() {
- registrationRetryTimer.foreach(_.cancel())
+ override def onStop(): Unit = {
+ if (registrationRetryTimer != null) {
+ registrationRetryTimer.cancel(true)
+ }
+ registrationRetryThread.shutdownNow()
+ registerMasterFutures.foreach(_.cancel(true))
+ registerMasterThreadPool.shutdownNow()
}
}
def start() {
// Just launch an actor; it will call back into the listener.
- actor = actorSystem.actorOf(Props(new ClientActor))
+ endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv))
}
def stop() {
- if (actor != null) {
+ if (endpoint != null) {
try {
- val timeout = RpcUtils.askTimeout(conf)
- val future = actor.ask(StopAppClient)(timeout)
- Await.result(future, timeout)
+ endpoint.askWithRetry[Boolean](StopAppClient)
} catch {
case e: TimeoutException =>
logInfo("Stop request to Master timed out; it may already be shut down.")
}
- actor = null
+ endpoint = null
}
}
}
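The reworked registerWithMaster above is the part the PR description flags for careful review: registration attempts run on a dedicated thread pool and a scheduled task checks the outcome, so nothing blocks the endpoint's message loop. Below is a trimmed, stand-alone sketch of that retry skeleton; ThreadUtils, Utils.tryOrExit, and the pool configuration mirror the diff, while the class name, the stubs, and the single-shot schedule call (the patch uses scheduleAtFixedRate) are simplifications.

    package org.apache.spark.example

    import java.util.concurrent.{Future => JFuture, SynchronousQueue, ThreadPoolExecutor, TimeUnit}

    import org.apache.spark.util.{ThreadUtils, Utils}

    class RegistrationRetrier(masterCount: Int) {

      @volatile private var registered = false
      private val REGISTRATION_TIMEOUT_SECONDS = 20
      private val REGISTRATION_RETRIES = 3

      // One thread per master so the blocking setupEndpointRef calls can run in parallel.
      private val registerMasterThreadPool = new ThreadPoolExecutor(
        0, masterCount, 60L, TimeUnit.SECONDS,
        new SynchronousQueue[Runnable](),
        ThreadUtils.namedThreadFactory("register-master-threadpool"))

      private val retryThread =
        ThreadUtils.newDaemonSingleThreadScheduledExecutor("registration-retry-thread")

      // Stubs standing in for the real registration attempts and failure handling.
      private def tryRegisterAllMasters(): Seq[JFuture[_]] = Seq.empty
      private def markDead(reason: String): Unit = ()

      /** Called once any master acknowledges the registration. */
      def markRegistered(): Unit = { registered = true }

      def registerWithMaster(nthRetry: Int): Unit = {
        val futures = tryRegisterAllMasters()
        retryThread.schedule(new Runnable {
          override def run(): Unit = Utils.tryOrExit {
            if (registered) {
              // Success: cancel outstanding attempts and tear down the pool.
              futures.foreach(_.cancel(true))
              registerMasterThreadPool.shutdownNow()
            } else if (nthRetry >= REGISTRATION_RETRIES) {
              markDead("All masters are unresponsive! Giving up.")
            } else {
              futures.foreach(_.cancel(true))
              registerWithMaster(nthRetry + 1)
            }
          }
        }, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
      }
    }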
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 40835b9550..1c79089303 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -17,9 +17,10 @@
package org.apache.spark.deploy.client
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.Utils
private[spark] object TestClient {
@@ -46,13 +47,12 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
val conf = new SparkConf
- val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0,
- conf = conf, securityManager = new SecurityManager(conf))
+ val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf))
val desc = new ApplicationDescription("TestClient", Some(1), 512,
Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored")
val listener = new TestListener
- val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf)
+ val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf)
client.start()
- actorSystem.awaitTermination()
+ rpcEnv.awaitTermination()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 1620e95bea..aa54ed9360 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -22,10 +22,9 @@ import java.util.Date
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import akka.actor.ActorRef
-
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.ApplicationDescription
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
private[spark] class ApplicationInfo(
@@ -33,7 +32,7 @@ private[spark] class ApplicationInfo(
val id: String,
val desc: ApplicationDescription,
val submitDate: Date,
- val driver: ActorRef,
+ val driver: RpcEndpointRef,
defaultCores: Int)
extends Serializable {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index fccceb3ea5..3e7c167228 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -21,20 +21,18 @@ import java.io.FileNotFoundException
import java.net.URLEncoder
import java.text.SimpleDateFormat
import java.util.Date
+import java.util.concurrent.{ScheduledFuture, TimeUnit}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.concurrent.Await
-import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
-import akka.actor._
-import akka.pattern.ask
-import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.Serialization
import akka.serialization.SerializationExtension
import org.apache.hadoop.fs.Path
+import org.apache.spark.rpc.akka.AkkaRpcEnv
+import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
ExecutorState, SparkHadoopUtil}
@@ -47,23 +45,27 @@ import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils}
+import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
private[master] class Master(
- host: String,
- port: Int,
+ override val rpcEnv: RpcEnv,
+ address: RpcAddress,
webUiPort: Int,
val securityMgr: SecurityManager,
val conf: SparkConf)
- extends Actor with ActorLogReceive with Logging with LeaderElectable {
+ extends ThreadSafeRpcEndpoint with Logging with LeaderElectable {
- import context.dispatcher // to use Akka's scheduler.schedule()
+ private val forwardMessageThread =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
+
+ // TODO Remove it once we don't use akka.serialization.Serialization
+ private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
- private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
+ private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
- private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000
+ private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000
private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200)
private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
@@ -75,10 +77,10 @@ private[master] class Master(
val apps = new HashSet[ApplicationInfo]
private val idToWorker = new HashMap[String, WorkerInfo]
- private val addressToWorker = new HashMap[Address, WorkerInfo]
+ private val addressToWorker = new HashMap[RpcAddress, WorkerInfo]
- private val actorToApp = new HashMap[ActorRef, ApplicationInfo]
- private val addressToApp = new HashMap[Address, ApplicationInfo]
+ private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo]
+ private val addressToApp = new HashMap[RpcAddress, ApplicationInfo]
private val completedApps = new ArrayBuffer[ApplicationInfo]
private var nextAppNumber = 0
private val appIdToUI = new HashMap[String, SparkUI]
@@ -89,21 +91,22 @@ private[master] class Master(
private val waitingDrivers = new ArrayBuffer[DriverInfo]
private var nextDriverNumber = 0
- Utils.checkHost(host, "Expected hostname")
+ Utils.checkHost(address.host, "Expected hostname")
private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf,
securityMgr)
private val masterSource = new MasterSource(this)
- private val webUi = new MasterWebUI(this, webUiPort)
+ // After onStart, webUi will be set
+ private var webUi: MasterWebUI = null
private val masterPublicAddress = {
val envVar = conf.getenv("SPARK_PUBLIC_DNS")
- if (envVar != null) envVar else host
+ if (envVar != null) envVar else address.host
}
- private val masterUrl = "spark://" + host + ":" + port
+ private val masterUrl = address.toSparkURL
private var masterWebUiUrl: String = _
private var state = RecoveryState.STANDBY
@@ -112,7 +115,9 @@ private[master] class Master(
private var leaderElectionAgent: LeaderElectionAgent = _
- private var recoveryCompletionTask: Cancellable = _
+ private var recoveryCompletionTask: ScheduledFuture[_] = _
+
+ private var checkForWorkerTimeOutTask: ScheduledFuture[_] = _
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
@@ -130,20 +135,23 @@ private[master] class Master(
private val restServer =
if (restServerEnabled) {
val port = conf.getInt("spark.master.rest.port", 6066)
- Some(new StandaloneRestServer(host, port, conf, self, masterUrl))
+ Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl))
} else {
None
}
private val restServerBoundPort = restServer.map(_.start())
- override def preStart() {
+ override def onStart(): Unit = {
logInfo("Starting Spark master at " + masterUrl)
logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
- // Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ webUi = new MasterWebUI(this, webUiPort)
webUi.bind()
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort
- context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
+ checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(CheckForWorkerTimeOut)
+ }
+ }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
@@ -157,16 +165,16 @@ private[master] class Master(
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
val zkFactory =
- new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system))
+ new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem))
(zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
val fsFactory =
- new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system))
+ new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
- .newInstance(conf, SerializationExtension(context.system))
+ .newInstance(conf, SerializationExtension(actorSystem))
.asInstanceOf[StandaloneRecoveryModeFactory]
(factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
@@ -176,18 +184,17 @@ private[master] class Master(
leaderElectionAgent = leaderElectionAgent_
}
- override def preRestart(reason: Throwable, message: Option[Any]) {
- super.preRestart(reason, message) // calls postStop()!
- logError("Master actor restarted due to exception", reason)
- }
-
- override def postStop() {
+ override def onStop() {
masterMetricsSystem.report()
applicationMetricsSystem.report()
// prevent the CompleteRecovery message sending to restarted master
if (recoveryCompletionTask != null) {
- recoveryCompletionTask.cancel()
+ recoveryCompletionTask.cancel(true)
}
+ if (checkForWorkerTimeOutTask != null) {
+ checkForWorkerTimeOutTask.cancel(true)
+ }
+ forwardMessageThread.shutdownNow()
webUi.stop()
restServer.foreach(_.stop())
masterMetricsSystem.stop()
@@ -197,14 +204,14 @@ private[master] class Master(
}
override def electedLeader() {
- self ! ElectedLeader
+ self.send(ElectedLeader)
}
override def revokedLeadership() {
- self ! RevokedLeadership
+ self.send(RevokedLeadership)
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receive: PartialFunction[Any, Unit] = {
case ElectedLeader => {
val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
@@ -215,8 +222,11 @@ private[master] class Master(
logInfo("I have been elected leader! New state: " + state)
if (state == RecoveryState.RECOVERING) {
beginRecovery(storedApps, storedDrivers, storedWorkers)
- recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self,
- CompleteRecovery)
+ recoveryCompletionTask = forwardMessageThread.schedule(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(CompleteRecovery)
+ }
+ }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
}
@@ -227,111 +237,42 @@ private[master] class Master(
System.exit(0)
}
- case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) =>
- {
+ case RegisterWorker(
+ id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
workerHost, workerPort, cores, Utils.megabytesToString(memory)))
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
} else if (idToWorker.contains(id)) {
- sender ! RegisterWorkerFailed("Duplicate worker ID")
+ workerRef.send(RegisterWorkerFailed("Duplicate worker ID"))
} else {
val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
- sender, workerUiPort, publicAddress)
+ workerRef, workerUiPort, publicAddress)
if (registerWorker(worker)) {
persistenceEngine.addWorker(worker)
- sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
+ workerRef.send(RegisteredWorker(self, masterWebUiUrl))
schedule()
} else {
- val workerAddress = worker.actor.path.address
+ val workerAddress = worker.endpoint.address
logWarning("Worker registration failed. Attempted to re-register worker at same " +
"address: " + workerAddress)
- sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: "
- + workerAddress)
+ workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: "
+ + workerAddress))
}
}
}
- case RequestSubmitDriver(description) => {
- if (state != RecoveryState.ALIVE) {
- val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
- "Can only accept driver submissions in ALIVE state."
- sender ! SubmitDriverResponse(false, None, msg)
- } else {
- logInfo("Driver submitted " + description.command.mainClass)
- val driver = createDriver(description)
- persistenceEngine.addDriver(driver)
- waitingDrivers += driver
- drivers.add(driver)
- schedule()
-
- // TODO: It might be good to instead have the submission client poll the master to determine
- // the current status of the driver. For now it's simply "fire and forget".
-
- sender ! SubmitDriverResponse(true, Some(driver.id),
- s"Driver successfully submitted as ${driver.id}")
- }
- }
-
- case RequestKillDriver(driverId) => {
- if (state != RecoveryState.ALIVE) {
- val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
- s"Can only kill drivers in ALIVE state."
- sender ! KillDriverResponse(driverId, success = false, msg)
- } else {
- logInfo("Asked to kill driver " + driverId)
- val driver = drivers.find(_.id == driverId)
- driver match {
- case Some(d) =>
- if (waitingDrivers.contains(d)) {
- waitingDrivers -= d
- self ! DriverStateChanged(driverId, DriverState.KILLED, None)
- } else {
- // We just notify the worker to kill the driver here. The final bookkeeping occurs
- // on the return path when the worker submits a state change back to the master
- // to notify it that the driver was successfully killed.
- d.worker.foreach { w =>
- w.actor ! KillDriver(driverId)
- }
- }
- // TODO: It would be nice for this to be a synchronous response
- val msg = s"Kill request for $driverId submitted"
- logInfo(msg)
- sender ! KillDriverResponse(driverId, success = true, msg)
- case None =>
- val msg = s"Driver $driverId has already finished or does not exist"
- logWarning(msg)
- sender ! KillDriverResponse(driverId, success = false, msg)
- }
- }
- }
-
- case RequestDriverStatus(driverId) => {
- if (state != RecoveryState.ALIVE) {
- val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
- "Can only request driver status in ALIVE state."
- sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))
- } else {
- (drivers ++ completedDrivers).find(_.id == driverId) match {
- case Some(driver) =>
- sender ! DriverStatusResponse(found = true, Some(driver.state),
- driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)
- case None =>
- sender ! DriverStatusResponse(found = false, None, None, None, None)
- }
- }
- }
-
- case RegisterApplication(description) => {
+ case RegisterApplication(description, driver) => {
+ // TODO Prevent repeated registrations from some driver
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
} else {
logInfo("Registering app " + description.name)
- val app = createApplication(description, sender)
+ val app = createApplication(description, driver)
registerApplication(app)
logInfo("Registered app " + description.name + " with ID " + app.id)
persistenceEngine.addApplication(app)
- sender ! RegisteredApplication(app.id, masterUrl)
+ driver.send(RegisteredApplication(app.id, self))
schedule()
}
}
@@ -343,7 +284,7 @@ private[master] class Master(
val appInfo = idToApp(appId)
exec.state = state
if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() }
- exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus)
+ exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus))
if (ExecutorState.isFinished(state)) {
// Remove this executor from the worker and app
logInfo(s"Removing executor ${exec.fullId} because it is $state")
@@ -384,7 +325,7 @@ private[master] class Master(
}
}
- case Heartbeat(workerId) => {
+ case Heartbeat(workerId, worker) => {
idToWorker.get(workerId) match {
case Some(workerInfo) =>
workerInfo.lastHeartbeat = System.currentTimeMillis()
@@ -392,7 +333,7 @@ private[master] class Master(
if (workers.map(_.id).contains(workerId)) {
logWarning(s"Got heartbeat from unregistered worker $workerId." +
" Asking it to re-register.")
- sender ! ReconnectWorker(masterUrl)
+ worker.send(ReconnectWorker(masterUrl))
} else {
logWarning(s"Got heartbeat from unregistered worker $workerId." +
" This worker was never registered, so ignoring the heartbeat.")
@@ -444,30 +385,103 @@ private[master] class Master(
logInfo(s"Received unregister request from application $applicationId")
idToApp.get(applicationId).foreach(finishApplication)
- case DisassociatedEvent(_, address, _) => {
- // The disconnected client could've been either a worker or an app; remove whichever it was
- logInfo(s"$address got disassociated, removing it.")
- addressToWorker.get(address).foreach(removeWorker)
- addressToApp.get(address).foreach(finishApplication)
- if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
+ case CheckForWorkerTimeOut => {
+ timeOutDeadWorkers()
}
+ }
- case RequestMasterState => {
- sender ! MasterStateResponse(
- host, port, restServerBoundPort,
- workers.toArray, apps.toArray, completedApps.toArray,
- drivers.toArray, completedDrivers.toArray, state)
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RequestSubmitDriver(description) => {
+ if (state != RecoveryState.ALIVE) {
+ val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
+ "Can only accept driver submissions in ALIVE state."
+ context.reply(SubmitDriverResponse(self, false, None, msg))
+ } else {
+ logInfo("Driver submitted " + description.command.mainClass)
+ val driver = createDriver(description)
+ persistenceEngine.addDriver(driver)
+ waitingDrivers += driver
+ drivers.add(driver)
+ schedule()
+
+ // TODO: It might be good to instead have the submission client poll the master to determine
+ // the current status of the driver. For now it's simply "fire and forget".
+
+ context.reply(SubmitDriverResponse(self, true, Some(driver.id),
+ s"Driver successfully submitted as ${driver.id}"))
+ }
}
- case CheckForWorkerTimeOut => {
- timeOutDeadWorkers()
+ case RequestKillDriver(driverId) => {
+ if (state != RecoveryState.ALIVE) {
+ val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
+ s"Can only kill drivers in ALIVE state."
+ context.reply(KillDriverResponse(self, driverId, success = false, msg))
+ } else {
+ logInfo("Asked to kill driver " + driverId)
+ val driver = drivers.find(_.id == driverId)
+ driver match {
+ case Some(d) =>
+ if (waitingDrivers.contains(d)) {
+ waitingDrivers -= d
+ self.send(DriverStateChanged(driverId, DriverState.KILLED, None))
+ } else {
+ // We just notify the worker to kill the driver here. The final bookkeeping occurs
+ // on the return path when the worker submits a state change back to the master
+ // to notify it that the driver was successfully killed.
+ d.worker.foreach { w =>
+ w.endpoint.send(KillDriver(driverId))
+ }
+ }
+ // TODO: It would be nice for this to be a synchronous response
+ val msg = s"Kill request for $driverId submitted"
+ logInfo(msg)
+ context.reply(KillDriverResponse(self, driverId, success = true, msg))
+ case None =>
+ val msg = s"Driver $driverId has already finished or does not exist"
+ logWarning(msg)
+ context.reply(KillDriverResponse(self, driverId, success = false, msg))
+ }
+ }
+ }
+
+ case RequestDriverStatus(driverId) => {
+ if (state != RecoveryState.ALIVE) {
+ val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
+ "Can only request driver status in ALIVE state."
+ context.reply(
+ DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))))
+ } else {
+ (drivers ++ completedDrivers).find(_.id == driverId) match {
+ case Some(driver) =>
+ context.reply(DriverStatusResponse(found = true, Some(driver.state),
+ driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception))
+ case None =>
+ context.reply(DriverStatusResponse(found = false, None, None, None, None))
+ }
+ }
+ }
+
+ case RequestMasterState => {
+ context.reply(MasterStateResponse(
+ address.host, address.port, restServerBoundPort,
+ workers.toArray, apps.toArray, completedApps.toArray,
+ drivers.toArray, completedDrivers.toArray, state))
}
case BoundPortsRequest => {
- sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort)
+ context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort))
}
}
+ override def onDisconnected(address: RpcAddress): Unit = {
+ // The disconnected client could've been either a worker or an app; remove whichever it was
+ logInfo(s"$address got disassociated, removing it.")
+ addressToWorker.get(address).foreach(removeWorker)
+ addressToApp.get(address).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
+ }
+
private def canCompleteRecovery =
workers.count(_.state == WorkerState.UNKNOWN) == 0 &&
apps.count(_.state == ApplicationState.UNKNOWN) == 0
@@ -479,7 +493,7 @@ private[master] class Master(
try {
registerApplication(app)
app.state = ApplicationState.UNKNOWN
- app.driver ! MasterChanged(masterUrl, masterWebUiUrl)
+ app.driver.send(MasterChanged(self, masterWebUiUrl))
} catch {
case e: Exception => logInfo("App " + app.id + " had exception on reconnect")
}
@@ -496,7 +510,7 @@ private[master] class Master(
try {
registerWorker(worker)
worker.state = WorkerState.UNKNOWN
- worker.actor ! MasterChanged(masterUrl, masterWebUiUrl)
+ worker.endpoint.send(MasterChanged(self, masterWebUiUrl))
} catch {
case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect")
}
@@ -504,6 +518,7 @@ private[master] class Master(
}
private def completeRecovery() {
+ // TODO Why synchronized
// Ensure "only-once" recovery semantics using a short synchronization period.
synchronized {
if (state != RecoveryState.RECOVERING) { return }
@@ -623,10 +638,10 @@ private[master] class Master(
private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(masterUrl,
- exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)
- exec.application.driver ! ExecutorAdded(
- exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
+ worker.endpoint.send(LaunchExecutor(masterUrl,
+ exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory))
+ exec.application.driver.send(ExecutorAdded(
+ exec.id, worker.id, worker.hostPort, exec.cores, exec.memory))
}
private def registerWorker(worker: WorkerInfo): Boolean = {
@@ -638,7 +653,7 @@ private[master] class Master(
workers -= w
}
- val workerAddress = worker.actor.path.address
+ val workerAddress = worker.endpoint.address
if (addressToWorker.contains(workerAddress)) {
val oldWorker = addressToWorker(workerAddress)
if (oldWorker.state == WorkerState.UNKNOWN) {
@@ -661,11 +676,11 @@ private[master] class Master(
logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port)
worker.setState(WorkerState.DEAD)
idToWorker -= worker.id
- addressToWorker -= worker.actor.path.address
+ addressToWorker -= worker.endpoint.address
for (exec <- worker.executors.values) {
logInfo("Telling app of lost executor: " + exec.id)
- exec.application.driver ! ExecutorUpdated(
- exec.id, ExecutorState.LOST, Some("worker lost"), None)
+ exec.application.driver.send(ExecutorUpdated(
+ exec.id, ExecutorState.LOST, Some("worker lost"), None))
exec.application.removeExecutor(exec)
}
for (driver <- worker.drivers.values) {
@@ -687,14 +702,15 @@ private[master] class Master(
schedule()
}
- private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
+ private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef):
+ ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores)
}
private def registerApplication(app: ApplicationInfo): Unit = {
- val appAddress = app.driver.path.address
+ val appAddress = app.driver.address
if (addressToApp.contains(appAddress)) {
logInfo("Attempted to re-register application at same address: " + appAddress)
return
@@ -703,7 +719,7 @@ private[master] class Master(
applicationMetricsSystem.registerSource(app.appSource)
apps += app
idToApp(app.id) = app
- actorToApp(app.driver) = app
+ endpointToApp(app.driver) = app
addressToApp(appAddress) = app
waitingApps += app
}
@@ -717,8 +733,8 @@ private[master] class Master(
logInfo("Removing app " + app.id)
apps -= app
idToApp -= app.id
- actorToApp -= app.driver
- addressToApp -= app.driver.path.address
+ endpointToApp -= app.driver
+ addressToApp -= app.driver.address
if (completedApps.size >= RETAINED_APPLICATIONS) {
val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1)
completedApps.take(toRemove).foreach( a => {
@@ -735,19 +751,19 @@ private[master] class Master(
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
- exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id)
+ exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id))
exec.state = ExecutorState.KILLED
}
app.markFinished(state)
if (state != ApplicationState.FINISHED) {
- app.driver ! ApplicationRemoved(state.toString)
+ app.driver.send(ApplicationRemoved(state.toString))
}
persistenceEngine.removeApplication(app)
schedule()
// Tell all workers that the application has finished, so they can clean up any app state.
workers.foreach { w =>
- w.actor ! ApplicationFinished(app.id)
+ w.endpoint.send(ApplicationFinished(app.id))
}
}
}
@@ -768,7 +784,7 @@ private[master] class Master(
}
val eventLogFilePrefix = EventLoggingListener.getLogPath(
- eventLogDir, app.id, None, app.desc.eventLogCodec)
+ eventLogDir, app.id, app.desc.eventLogCodec)
val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf)
val inProgressExists = fs.exists(new Path(eventLogFilePrefix +
EventLoggingListener.IN_PROGRESS))
@@ -832,14 +848,14 @@ private[master] class Master(
private def timeOutDeadWorkers() {
// Copy the workers into an array so we don't modify the hashset while iterating through it
val currentTime = System.currentTimeMillis()
- val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray
+ val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray
for (worker <- toRemove) {
if (worker.state != WorkerState.DEAD) {
logWarning("Removing %s because we got no heartbeat in %d seconds".format(
- worker.id, WORKER_TIMEOUT/1000))
+ worker.id, WORKER_TIMEOUT_MS / 1000))
removeWorker(worker)
} else {
- if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) {
+ if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) {
workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it
}
}
@@ -862,7 +878,7 @@ private[master] class Master(
logInfo("Launching driver " + driver.id + " on worker " + worker.id)
worker.addDriver(driver)
driver.worker = Some(worker)
- worker.actor ! LaunchDriver(driver.id, driver.desc)
+ worker.endpoint.send(LaunchDriver(driver.id, driver.desc))
driver.state = DriverState.RUNNING
}
@@ -891,57 +907,33 @@ private[master] class Master(
}
private[deploy] object Master extends Logging {
- val systemName = "sparkMaster"
- private val actorName = "Master"
+ val SYSTEM_NAME = "sparkMaster"
+ val ENDPOINT_NAME = "Master"
def main(argStrings: Array[String]) {
SignalLogger.register(log)
val conf = new SparkConf
val args = new MasterArguments(argStrings, conf)
- val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf)
- actorSystem.awaitTermination()
- }
-
- /**
- * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`.
- *
- * @throws SparkException if the url is invalid
- */
- def toAkkaUrl(sparkUrl: String, protocol: String): String = {
- val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
- AkkaUtils.address(protocol, systemName, host, port, actorName)
- }
-
- /**
- * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`.
- *
- * @throws SparkException if the url is invalid
- */
- def toAkkaAddress(sparkUrl: String, protocol: String): Address = {
- val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
- Address(protocol, systemName, host, port)
+ val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
+ rpcEnv.awaitTermination()
}
/**
- * Start the Master and return a four tuple of:
- * (1) The Master actor system
- * (2) The bound port
- * (3) The web UI bound port
- * (4) The REST server bound port, if any
+ * Start the Master and return a three tuple of:
+ * (1) The Master RpcEnv
+ * (2) The web UI bound port
+ * (3) The REST server bound port, if any
*/
- def startSystemAndActor(
+ def startRpcEnvAndEndpoint(
host: String,
port: Int,
webUiPort: Int,
- conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = {
+ conf: SparkConf): (RpcEnv, Int, Option[Int]) = {
val securityMgr = new SecurityManager(conf)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
- securityManager = securityMgr)
- val actor = actorSystem.actorOf(
- Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName)
- val timeout = RpcUtils.askTimeout(conf)
- val portsRequest = actor.ask(BoundPortsRequest)(timeout)
- val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse]
- (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort)
+ val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
+ val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME,
+ new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
+ val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest)
+ (rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
}
}
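
For reference, a minimal sketch of driving the new startRpcEnvAndEndpoint entry point shown above, mirroring what Master.main does after argument parsing. The host and port values are illustrative, and since Master is private[deploy] the sketch assumes it is compiled under the org.apache.spark.deploy package.

    package org.apache.spark.deploy

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.master.Master

    object StartMasterSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        // Returns the Master's RpcEnv plus the bound web UI port and optional REST port
        val (rpcEnv, webUiPort, restPort) =
          Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf)
        println(s"Master web UI on port $webUiPort, REST server on $restPort")
        rpcEnv.awaitTermination() // block until the RpcEnv is shut down
      }
    }
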
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
index 15c6296888..68c937188b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -28,7 +28,7 @@ private[master] object MasterMessages {
case object RevokedLeadership
- // Actor System to Master
+ // Master to itself
case object CheckForWorkerTimeOut
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index 9b3d48c6ed..471811037e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -19,9 +19,7 @@ package org.apache.spark.deploy.master
import scala.collection.mutable
-import akka.actor.ActorRef
-
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
@@ -30,7 +28,7 @@ private[spark] class WorkerInfo(
val port: Int,
val cores: Int,
val memory: Int,
- val actor: ActorRef,
+ val endpoint: RpcEndpointRef,
val webUiPort: Int,
val publicAddress: String)
extends Serializable {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 52758d6a7c..6fdff86f66 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -17,10 +17,7 @@
package org.apache.spark.deploy.master
-import akka.actor.ActorRef
-
import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.deploy.master.MasterMessages._
import org.apache.curator.framework.CuratorFramework
import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
import org.apache.spark.deploy.SparkCuratorUtil
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 06e265f99e..e28e7e379a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -19,11 +19,8 @@ package org.apache.spark.deploy.master.ui
import javax.servlet.http.HttpServletRequest
-import scala.concurrent.Await
import scala.xml.Node
-import akka.pattern.ask
-
import org.apache.spark.deploy.ExecutorState
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master.ExecutorDesc
@@ -32,14 +29,12 @@ import org.apache.spark.util.Utils
private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") {
- private val master = parent.masterActorRef
- private val timeout = parent.timeout
+ private val master = parent.masterEndpointRef
/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
- val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, timeout)
+ val state = master.askWithRetry[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 6a7c74020b..c3e20ebf8d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -19,25 +19,21 @@ package org.apache.spark.deploy.master.ui
import javax.servlet.http.HttpServletRequest
-import scala.concurrent.Await
import scala.xml.Node
-import akka.pattern.ask
import org.json4s.JValue
import org.apache.spark.deploy.JsonProtocol
-import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState}
+import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master._
import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
- private val master = parent.masterActorRef
- private val timeout = parent.timeout
+ private val master = parent.masterEndpointRef
def getMasterState: MasterStateResponse = {
- val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- Await.result(stateFuture, timeout)
+ master.askWithRetry[MasterStateResponse](RequestMasterState)
}
override def renderJson(request: HttpServletRequest): JValue = {
@@ -53,7 +49,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
def handleDriverKillRequest(request: HttpServletRequest): Unit = {
- handleKillRequest(request, id => { master ! RequestKillDriver(id) })
+ handleKillRequest(request, id => {
+ master.ask[KillDriverResponse](RequestKillDriver(id))
+ })
}
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
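
Both UI pages now make a single blocking, typed askWithRetry call where they previously combined ask with Await.result. A compact end-to-end sketch of that request/reply shape against a throwaway local RpcEnv; the SummaryRequest/SummaryResponse pair is hypothetical (the real messages are RequestMasterState and MasterStateResponse), and because the rpc package is private[spark] the sketch assumes compilation under org.apache.spark.

    package org.apache.spark.deploy

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    case object SummaryRequest                          // hypothetical request
    case class SummaryResponse(workers: Int, apps: Int) // hypothetical reply

    object AskSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        val env = RpcEnv.create("askSketch", "localhost", 0, conf, new SecurityManager(conf))
        val ref = env.setupEndpoint("summary", new RpcEndpoint {
          override val rpcEnv: RpcEnv = env
          // Replies are produced in receiveAndReply, as in the Master and Worker endpoints
          override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
            case SummaryRequest => context.reply(SummaryResponse(workers = 2, apps = 1))
          }
        })
        // One blocking, typed call replaces the old ask + Await.result pair
        val state = ref.askWithRetry[SummaryResponse](SummaryRequest)
        println(s"workers=${state.workers} apps=${state.apps}")
        env.shutdown()
        env.awaitTermination()
      }
    }
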
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 2111a8581f..6174fc11f8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -23,7 +23,6 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource
UIRoot}
import org.apache.spark.ui.{SparkUI, WebUI}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.RpcUtils
/**
* Web UI server for the standalone master.
@@ -33,8 +32,7 @@ class MasterWebUI(val master: Master, requestedPort: Int)
extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging
with UIRoot {
- val masterActorRef = master.self
- val timeout = RpcUtils.askTimeout(master.conf)
+ val masterEndpointRef = master.self
val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true)
val masterPage = new MasterPage(this)
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index 502b9bb701..d5b9bcab14 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -20,10 +20,10 @@ package org.apache.spark.deploy.rest
import java.io.File
import javax.servlet.http.HttpServletResponse
-import akka.actor.ActorRef
import org.apache.spark.deploy.ClientArguments._
import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription}
-import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.util.Utils
import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
/**
@@ -45,35 +45,34 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
* @param host the address this server should bind to
* @param requestedPort the port this server will attempt to bind to
* @param masterConf the conf used by the Master
- * @param masterActor reference to the Master actor to which requests can be sent
+ * @param masterEndpoint reference to the Master endpoint to which requests can be sent
* @param masterUrl the URL of the Master new drivers will attempt to connect to
*/
private[deploy] class StandaloneRestServer(
host: String,
requestedPort: Int,
masterConf: SparkConf,
- masterActor: ActorRef,
+ masterEndpoint: RpcEndpointRef,
masterUrl: String)
extends RestSubmissionServer(host, requestedPort, masterConf) {
protected override val submitRequestServlet =
- new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf)
+ new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf)
protected override val killRequestServlet =
- new StandaloneKillRequestServlet(masterActor, masterConf)
+ new StandaloneKillRequestServlet(masterEndpoint, masterConf)
protected override val statusRequestServlet =
- new StandaloneStatusRequestServlet(masterActor, masterConf)
+ new StandaloneStatusRequestServlet(masterEndpoint, masterConf)
}
/**
* A servlet for handling kill requests passed to the [[StandaloneRestServer]].
*/
-private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf)
+private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf)
extends KillRequestServlet {
protected def handleKill(submissionId: String): KillSubmissionResponse = {
- val askTimeout = RpcUtils.askTimeout(conf)
- val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
- DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
+ val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse](
+ DeployMessages.RequestKillDriver(submissionId))
val k = new KillSubmissionResponse
k.serverSparkVersion = sparkVersion
k.message = response.message
@@ -86,13 +85,12 @@ private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: Sp
/**
* A servlet for handling status requests passed to the [[StandaloneRestServer]].
*/
-private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
+private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf)
extends StatusRequestServlet {
protected def handleStatus(submissionId: String): SubmissionStatusResponse = {
- val askTimeout = RpcUtils.askTimeout(conf)
- val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
- DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
+ val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse](
+ DeployMessages.RequestDriverStatus(submissionId))
val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) }
val d = new SubmissionStatusResponse
d.serverSparkVersion = sparkVersion
@@ -110,7 +108,7 @@ private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf:
* A servlet for handling submit requests passed to the [[StandaloneRestServer]].
*/
private[rest] class StandaloneSubmitRequestServlet(
- masterActor: ActorRef,
+ masterEndpoint: RpcEndpointRef,
masterUrl: String,
conf: SparkConf)
extends SubmitRequestServlet {
@@ -175,10 +173,9 @@ private[rest] class StandaloneSubmitRequestServlet(
responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
requestMessage match {
case submitRequest: CreateSubmissionRequest =>
- val askTimeout = RpcUtils.askTimeout(conf)
val driverDescription = buildDriverDescription(submitRequest)
- val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse](
- DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout)
+ val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse](
+ DeployMessages.RequestSubmitDriver(driverDescription))
val submitResponse = new CreateSubmissionResponse
submitResponse.serverSparkVersion = sparkVersion
submitResponse.message = response.message
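
The REST server is now handed an RpcEndpointRef for the Master instead of an ActorRef. A minimal sketch of that wiring from a separate process, assuming a standalone Master is already listening on 127.0.0.1:7077; the hosts and ports are illustrative, and because these classes are private[deploy] the sketch assumes compilation under org.apache.spark.deploy.

    package org.apache.spark.deploy

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.deploy.master.Master
    import org.apache.spark.deploy.rest.StandaloneRestServer
    import org.apache.spark.rpc.{RpcAddress, RpcEnv}

    object RestWiringSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        val env = RpcEnv.create("restSketch", "localhost", 0, conf, new SecurityManager(conf))
        // Look up the running Master endpoint and hand its ref to the REST server
        val masterRef = env.setupEndpointRef(
          Master.SYSTEM_NAME, RpcAddress("127.0.0.1", 7077), Master.ENDPOINT_NAME)
        val server = new StandaloneRestServer(
          "localhost", 6066, conf, masterRef, "spark://127.0.0.1:7077")
        val boundPort = server.start()
        println(s"REST submission server bound to port $boundPort")
      }
    }
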
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 1386055eb8..ec51c3d935 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -21,7 +21,6 @@ import java.io._
import scala.collection.JavaConversions._
-import akka.actor.ActorRef
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.fs.Path
@@ -31,6 +30,7 @@ import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil}
import org.apache.spark.deploy.DeployMessages.DriverStateChanged
import org.apache.spark.deploy.master.DriverState
import org.apache.spark.deploy.master.DriverState.DriverState
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.{Utils, Clock, SystemClock}
/**
@@ -43,7 +43,7 @@ private[deploy] class DriverRunner(
val workDir: File,
val sparkHome: File,
val driverDesc: DriverDescription,
- val worker: ActorRef,
+ val worker: RpcEndpointRef,
val workerUrl: String,
val securityManager: SecurityManager)
extends Logging {
@@ -107,7 +107,7 @@ private[deploy] class DriverRunner(
finalState = Some(state)
- worker ! DriverStateChanged(driverId, state, finalException)
+ worker.send(DriverStateChanged(driverId, state, finalException))
}
}.start()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index fff17e1095..29a5042285 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -21,10 +21,10 @@ import java.io._
import scala.collection.JavaConversions._
-import akka.actor.ActorRef
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
@@ -41,7 +41,7 @@ private[deploy] class ExecutorRunner(
val appDesc: ApplicationDescription,
val cores: Int,
val memory: Int,
- val worker: ActorRef,
+ val worker: RpcEndpointRef,
val workerId: String,
val host: String,
val webUiPort: Int,
@@ -91,7 +91,7 @@ private[deploy] class ExecutorRunner(
process.destroy()
exitCode = Some(process.waitFor())
}
- worker ! ExecutorStateChanged(appId, execId, state, message, exitCode)
+ worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
}
/** Stop this executor runner, including killing the process it launched */
@@ -159,7 +159,7 @@ private[deploy] class ExecutorRunner(
val exitCode = process.waitFor()
state = ExecutorState.EXITED
val message = "Command exited with code " + exitCode
- worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))
+ worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)))
} catch {
case interrupted: InterruptedException => {
logInfo("Runner thread for executor " + fullId + " interrupted")
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index ebc6cd76c6..82e9578bbc 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -21,15 +21,14 @@ import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
import java.util.{UUID, Date}
+import java.util.concurrent._
+import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, HashSet}
-import scala.concurrent.duration._
-import scala.language.postfixOps
+import scala.concurrent.ExecutionContext
import scala.util.Random
-
-import akka.actor._
-import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+import scala.util.control.NonFatal
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState}
@@ -38,32 +37,39 @@ import org.apache.spark.deploy.ExternalShuffleService
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.rpc._
+import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
-/**
- * @param masterAkkaUrls Each url should be a valid akka url.
- */
private[worker] class Worker(
- host: String,
- port: Int,
+ override val rpcEnv: RpcEnv,
webUiPort: Int,
cores: Int,
memory: Int,
- masterAkkaUrls: Array[String],
- actorSystemName: String,
- actorName: String,
+ masterRpcAddresses: Array[RpcAddress],
+ systemName: String,
+ endpointName: String,
workDirPath: String = null,
val conf: SparkConf,
val securityMgr: SecurityManager)
- extends Actor with ActorLogReceive with Logging {
- import context.dispatcher
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ private val host = rpcEnv.address.host
+ private val port = rpcEnv.address.port
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
+ // A scheduled executor used to send messages at the specified time.
+ private val forwordMessageScheduler =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler")
+
+ // A separate thread to clean up the workDir. It provides the ExecutionContext passed to the
+ // `Future` methods used for the cleanup.
+ private val cleanupThreadExecutor = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread"))
+
// For worker and executor IDs
private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
-
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4
@@ -79,32 +85,26 @@ private[worker] class Worker(
val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits)
randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND
}
- private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 *
- REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
- private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60
- * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
+ private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 *
+ REGISTRATION_RETRY_FUZZ_MULTIPLIER))
+ private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60
+ * REGISTRATION_RETRY_FUZZ_MULTIPLIER))
private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false)
// How often worker will clean up old app folders
private val CLEANUP_INTERVAL_MILLIS =
conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000
// TTL for app folders/data; after TTL expires it will be cleaned up
- private val APP_DATA_RETENTION_SECS =
+ private val APP_DATA_RETENTION_SECONDS =
conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600)
private val testing: Boolean = sys.props.contains("spark.testing")
- private var master: ActorSelection = null
- private var masterAddress: Address = null
+ private var master: Option[RpcEndpointRef] = None
private var activeMasterUrl: String = ""
private[worker] var activeMasterWebUiUrl : String = ""
- private val akkaUrl = AkkaUtils.address(
- AkkaUtils.protocol(context.system),
- actorSystemName,
- host,
- port,
- actorName)
- @volatile private var registered = false
- @volatile private var connected = false
+ private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName)
+ private var registered = false
+ private var connected = false
private val workerId = generateWorkerId()
private val sparkHome =
if (testing) {
@@ -136,7 +136,18 @@ private[worker] class Worker(
private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
private val workerSource = new WorkerSource(this)
- private var registrationRetryTimer: Option[Cancellable] = None
+ private var registerMasterFutures: Array[JFuture[_]] = null
+ private var registrationRetryTimer: Option[JScheduledFuture[_]] = None
+
+ // A thread pool for registering with masters. Because registering with a master is a blocking
+ // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
+ // time so that we can register with all masters.
+ private val registerMasterThreadPool = new ThreadPoolExecutor(
+ 0,
+ masterRpcAddresses.size, // Make sure we can register with all masters at the same time
+ 60L, TimeUnit.SECONDS,
+ new SynchronousQueue[Runnable](),
+ ThreadUtils.namedThreadFactory("worker-register-master-threadpool"))
var coresUsed = 0
var memoryUsed = 0
@@ -162,14 +173,13 @@ private[worker] class Worker(
}
}
- override def preStart() {
+ override def onStart() {
assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
logInfo("Spark home: " + sparkHome)
createWorkDir()
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
@@ -181,24 +191,32 @@ private[worker] class Worker(
metricsSystem.getServletHandlers.foreach(webUi.attachHandler)
}
- private def changeMaster(url: String, uiUrl: String) {
+ private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) {
// activeMasterUrl it's a valid Spark url since we receive it from master.
- activeMasterUrl = url
+ activeMasterUrl = masterRef.address.toSparkURL
activeMasterWebUiUrl = uiUrl
- master = context.actorSelection(
- Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system)))
- masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system))
+ master = Some(masterRef)
connected = true
// Cancel any outstanding re-registration attempts because we found a new master
- registrationRetryTimer.foreach(_.cancel())
- registrationRetryTimer = None
+ cancelLastRegistrationRetry()
}
- private def tryRegisterAllMasters() {
- for (masterAkkaUrl <- masterAkkaUrls) {
- logInfo("Connecting to master " + masterAkkaUrl + "...")
- val actor = context.actorSelection(masterAkkaUrl)
- actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
+ private def tryRegisterAllMasters(): Array[JFuture[_]] = {
+ masterRpcAddresses.map { masterAddress =>
+ registerMasterThreadPool.submit(new Runnable {
+ override def run(): Unit = {
+ try {
+ logInfo("Connecting to master " + masterAddress + "...")
+ val masterEndpoint =
+ rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
+ masterEndpoint.send(RegisterWorker(
+ workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
+ } catch {
+ case ie: InterruptedException => // Cancelled
+ case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
+ }
+ }
+ })
}
}
@@ -211,8 +229,7 @@ private[worker] class Worker(
Utils.tryOrExit {
connectionAttemptCount += 1
if (registered) {
- registrationRetryTimer.foreach(_.cancel())
- registrationRetryTimer = None
+ cancelLastRegistrationRetry()
} else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
/**
@@ -235,21 +252,48 @@ private[worker] class Worker(
* still not safe if the old master recovers within this interval, but this is a much
* less likely scenario.
*/
- if (master != null) {
- master ! RegisterWorker(
- workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
- } else {
- // We are retrying the initial registration
- tryRegisterAllMasters()
+ master match {
+ case Some(masterRef) =>
+ // registered == false && master != None means we lost the connection to the master, so
+ // masterRef cannot be used and we need to recreate it. Note: we must not set
+ // master to None, for the reasons given in the comment above.
+ if (registerMasterFutures != null) {
+ registerMasterFutures.foreach(_.cancel(true))
+ }
+ val masterAddress = masterRef.address
+ registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable {
+ override def run(): Unit = {
+ try {
+ logInfo("Connecting to master " + masterAddress + "...")
+ val masterEndpoint =
+ rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
+ masterEndpoint.send(RegisterWorker(
+ workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
+ } catch {
+ case ie: InterruptedException => // Cancelled
+ case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
+ }
+ }
+ }))
+ case None =>
+ if (registerMasterFutures != null) {
+ registerMasterFutures.foreach(_.cancel(true))
+ }
+ // We are retrying the initial registration
+ registerMasterFutures = tryRegisterAllMasters()
}
// We have exceeded the initial registration retry threshold
// All retries from now on should use a higher interval
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
- registrationRetryTimer.foreach(_.cancel())
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL,
- PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
- }
+ registrationRetryTimer.foreach(_.cancel(true))
+ registrationRetryTimer = Some(
+ forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(ReregisterWithMaster)
+ }
+ }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
+ PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
+ TimeUnit.SECONDS))
}
} else {
logError("All masters are unresponsive! Giving up.")
@@ -258,41 +302,67 @@ private[worker] class Worker(
}
}
+ /**
+ * Cancel the last registration retry, or do nothing if there is no pending retry
+ */
+ private def cancelLastRegistrationRetry(): Unit = {
+ if (registerMasterFutures != null) {
+ registerMasterFutures.foreach(_.cancel(true))
+ registerMasterFutures = null
+ }
+ registrationRetryTimer.foreach(_.cancel(true))
+ registrationRetryTimer = None
+ }
+
private def registerWithMaster() {
- // DisassociatedEvent may be triggered multiple times, so don't attempt registration
+ // onDisconnected may be triggered multiple times, so don't attempt registration
// if there are outstanding registration attempts scheduled.
registrationRetryTimer match {
case None =>
registered = false
- tryRegisterAllMasters()
+ registerMasterFutures = tryRegisterAllMasters()
connectionAttemptCount = 0
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL,
- INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
- }
+ registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate(
+ new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(ReregisterWithMaster)
+ }
+ },
+ INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
+ INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
+ TimeUnit.SECONDS))
case Some(_) =>
logInfo("Not spawning another attempt to register with the master, since there is an" +
" attempt scheduled already.")
}
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RegisteredWorker(masterUrl, masterWebUiUrl) =>
- logInfo("Successfully registered with master " + masterUrl)
+ override def receive: PartialFunction[Any, Unit] = {
+ case RegisteredWorker(masterRef, masterWebUiUrl) =>
+ logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
registered = true
- changeMaster(masterUrl, masterWebUiUrl)
- context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
+ changeMaster(masterRef, masterWebUiUrl)
+ forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(SendHeartbeat)
+ }
+ }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
if (CLEANUP_ENABLED) {
logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
- context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis,
- CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup)
+ forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(WorkDirCleanup)
+ }
+ }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
}
case SendHeartbeat =>
- if (connected) { master ! Heartbeat(workerId) }
+ if (connected) { sendToMaster(Heartbeat(workerId, self)) }
case WorkDirCleanup =>
// Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor
+ // Copy the app ids so that they can be used in the cleanup thread.
+ val appIds = executors.values.map(_.appId).toSet
val cleanupFuture = concurrent.future {
val appDirs = workDir.listFiles()
if (appDirs == null) {
@@ -302,27 +372,27 @@ private[worker] class Worker(
// the directory is used by an application - check that the application is not running
// when cleaning up
val appIdFromDir = dir.getName
- val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir)
+ val isAppStillRunning = appIds.contains(appIdFromDir)
dir.isDirectory && !isAppStillRunning &&
- !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS)
+ !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS)
}.foreach { dir =>
logInfo(s"Removing directory: ${dir.getPath}")
Utils.deleteRecursively(dir)
}
- }
+ }(cleanupThreadExecutor)
- cleanupFuture onFailure {
+ cleanupFuture.onFailure {
case e: Throwable =>
logError("App dir cleanup failed: " + e.getMessage, e)
- }
+ }(cleanupThreadExecutor)
- case MasterChanged(masterUrl, masterWebUiUrl) =>
- logInfo("Master has changed, new master is at " + masterUrl)
- changeMaster(masterUrl, masterWebUiUrl)
+ case MasterChanged(masterRef, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL)
+ changeMaster(masterRef, masterWebUiUrl)
val execs = executors.values.
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
- sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)
+ masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq))
case RegisterWorkerFailed(message) =>
if (!registered) {
@@ -369,14 +439,14 @@ private[worker] class Worker(
publicAddress,
sparkHome,
executorDir,
- akkaUrl,
+ workerUri,
conf,
appLocalDirs, ExecutorState.LOADING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
- master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
+ sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None))
} catch {
case e: Exception => {
logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
@@ -384,14 +454,14 @@ private[worker] class Worker(
executors(appId + "/" + execId).kill()
executors -= appId + "/" + execId
}
- master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
- Some(e.toString), None)
+ sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
+ Some(e.toString), None))
}
}
}
- case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
- master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
+ sendToMaster(executorStateChanged)
val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
executors.get(fullId) match {
@@ -434,7 +504,7 @@ private[worker] class Worker(
sparkHome,
driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)),
self,
- akkaUrl,
+ workerUri,
securityMgr)
drivers(driverId) = driver
driver.start()
@@ -453,7 +523,7 @@ private[worker] class Worker(
}
}
- case DriverStateChanged(driverId, state, exception) => {
+ case driverStageChanged @ DriverStateChanged(driverId, state, exception) => {
state match {
case DriverState.ERROR =>
logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}")
@@ -466,23 +536,13 @@ private[worker] class Worker(
case _ =>
logDebug(s"Driver $driverId changed state to $state")
}
- master ! DriverStateChanged(driverId, state, exception)
+ sendToMaster(driverStageChanged)
val driver = drivers.remove(driverId).get
finishedDrivers(driverId) = driver
memoryUsed -= driver.driverDesc.mem
coresUsed -= driver.driverDesc.cores
}
- case x: DisassociatedEvent if x.remoteAddress == masterAddress =>
- logInfo(s"$x Disassociated !")
- masterDisconnected()
-
- case RequestWorkerState =>
- sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
- finishedExecutors.values.toList, drivers.values.toList,
- finishedDrivers.values.toList, activeMasterUrl, cores, memory,
- coresUsed, memoryUsed, activeMasterWebUiUrl)
-
case ReregisterWithMaster =>
reregisterWithMaster()
@@ -491,6 +551,21 @@ private[worker] class Worker(
maybeCleanupApplication(id)
}
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RequestWorkerState =>
+ context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList,
+ finishedExecutors.values.toList, drivers.values.toList,
+ finishedDrivers.values.toList, activeMasterUrl, cores, memory,
+ coresUsed, memoryUsed, activeMasterWebUiUrl))
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (master.exists(_.address == remoteAddress)) {
+ logInfo(s"$remoteAddress Disassociated !")
+ masterDisconnected()
+ }
+ }
+
private def masterDisconnected() {
logError("Connection to master failed! Waiting for master to reconnect...")
connected = false
@@ -510,13 +585,29 @@ private[worker] class Worker(
}
}
+ /**
+ * Send a message to the current master. If we have not yet registered successfully with any
+ * master, the message will be dropped.
+ */
+ private def sendToMaster(message: Any): Unit = {
+ master match {
+ case Some(masterRef) => masterRef.send(message)
+ case None =>
+ logWarning(
+ s"Dropping $message because the connection to master has not yet been established")
+ }
+ }
+
private def generateWorkerId(): String = {
"worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port)
}
- override def postStop() {
+ override def onStop() {
+ cleanupThreadExecutor.shutdownNow()
metricsSystem.report()
- registrationRetryTimer.foreach(_.cancel())
+ cancelLastRegistrationRetry()
+ forwordMessageScheduler.shutdownNow()
+ registerMasterThreadPool.shutdownNow()
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
shuffleService.stop()
@@ -530,12 +621,12 @@ private[deploy] object Worker extends Logging {
SignalLogger.register(log)
val conf = new SparkConf
val args = new WorkerArguments(argStrings, conf)
- val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
+ val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.masters, args.workDir)
- actorSystem.awaitTermination()
+ rpcEnv.awaitTermination()
}
- def startSystemAndActor(
+ def startRpcEnvAndEndpoint(
host: String,
port: Int,
webUiPort: Int,
@@ -544,18 +635,17 @@ private[deploy] object Worker extends Logging {
masterUrls: Array[String],
workDir: String,
workerNumber: Option[Int] = None,
- conf: SparkConf = new SparkConf): (ActorSystem, Int) = {
+ conf: SparkConf = new SparkConf): RpcEnv = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val actorName = "Worker"
val securityMgr = new SecurityManager(conf)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
- conf = conf, securityManager = securityMgr)
- val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
- actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
- (actorSystem, boundPort)
+ val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
+ val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
+ rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, webUiPort, cores, memory, masterAddresses,
+ systemName, actorName, workDir, conf, securityMgr))
+ rpcEnv
}
def isUseLocalNodeSSLConfig(cmd: Command): Boolean = {
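
Registration with the masters is now a blocking call performed off the message loop: one task per master address is submitted to a pool that can grow to masterRpcAddresses.size threads, and the resulting futures are cancelled once a registration succeeds. A self-contained sketch of that submit/cancel pattern with plain java.util.concurrent; the addresses and the println body are placeholders for the real RPC, and the named thread factory from ThreadUtils is omitted.

    import java.util.concurrent.{Future => JFuture, SynchronousQueue, ThreadPoolExecutor, TimeUnit}

    object RegisterAllMastersSketch {
      def main(args: Array[String]): Unit = {
        val masterAddresses = Seq("10.0.0.1:7077", "10.0.0.2:7077") // placeholder addresses
        // Core size 0, max size = number of masters, so every registration can run at once;
        // the SynchronousQueue forces a new thread per submitted task up to that maximum.
        val pool = new ThreadPoolExecutor(
          0, masterAddresses.size, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable]())
        val futures: Seq[JFuture[_]] = masterAddresses.map { addr =>
          pool.submit(new Runnable {
            override def run(): Unit = println(s"registering with $addr") // stand-in for the blocking RPC
          })
        }
        // Once registered (or on shutdown), the Worker cancels whatever is still outstanding
        futures.foreach(_.cancel(true))
        pool.shutdownNow()
      }
    }
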
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index 83fb991891..fae5640b9a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -18,7 +18,6 @@
package org.apache.spark.deploy.worker
import org.apache.spark.Logging
-import org.apache.spark.deploy.DeployMessages.SendHeartbeat
import org.apache.spark.rpc._
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
index 9f9f27d71e..fd905feb97 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
@@ -17,10 +17,8 @@
package org.apache.spark.deploy.worker.ui
-import scala.concurrent.Await
import scala.xml.Node
-import akka.pattern.ask
import javax.servlet.http.HttpServletRequest
import org.json4s.JValue
@@ -32,18 +30,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
- private val workerActor = parent.worker.self
- private val timeout = parent.timeout
+ private val workerEndpoint = parent.worker.self
override def renderJson(request: HttpServletRequest): JValue = {
- val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, timeout)
+ val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState)
JsonProtocol.writeWorkerState(workerState)
}
def render(request: HttpServletRequest): Seq[Node] = {
- val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, timeout)
+ val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState)
val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs")
val runningExecutors = workerState.executors
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 12b6b28d4d..3b6938ec63 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -158,6 +158,8 @@ private[spark] case class RpcAddress(host: String, port: Int) {
val hostPort: String = host + ":" + port
override val toString: String = hostPort
+
+ def toSparkURL: String = "spark://" + hostPort
}
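
toSparkURL is the inverse of RpcAddress.fromSparkURL used by the Worker object above; a quick round-trip sketch follows. RpcAddress is private[spark], so this assumes compilation under the org.apache.spark package (the package name below is only illustrative), much like the RpcAddressSuite added at the end of this patch.

    package org.apache.spark.sketches

    import org.apache.spark.rpc.RpcAddress

    object SparkUrlRoundTrip {
      def main(args: Array[String]): Unit = {
        val addr = RpcAddress("1.2.3.4", 7077)
        assert(addr.toSparkURL == "spark://1.2.3.4:7077")        // host:port prefixed with spark://
        assert(RpcAddress.fromSparkURL(addr.toSparkURL) == addr) // and back again
      }
    }
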
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 0161962cde..31ebe5ac5b 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -180,10 +180,10 @@ private[spark] class AkkaRpcEnv private[akka] (
})
} catch {
case NonFatal(e) =>
- if (needReply) {
- // If the sender asks a reply, we should send the error back to the sender
- _sender ! AkkaFailure(e)
- } else {
+ _sender ! AkkaFailure(e)
+ if (!needReply) {
+ // If the sender does not expect a reply, it may never handle the failure we just sent back.
+ // Rethrow "e" to make sure it is still processed.
throw e
}
}
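
The revised error path above does two things: the failure is always sent back to the sender, and the exception is additionally rethrown when no reply was requested so that it still reaches the endpoint's error handling. A tiny self-contained sketch of that decision in isolation; replyFailure stands in for `_sender ! AkkaFailure(e)`.

    object ErrorStrategySketch {
      // Mirrors the catch block above: always tell the sender about the failure, and
      // rethrow only when nobody is waiting for a reply, so the error is not swallowed.
      def handle(e: Throwable, needReply: Boolean)(replyFailure: Throwable => Unit): Unit = {
        replyFailure(e)
        if (!needReply) {
          throw e
        }
      }

      def main(args: Array[String]): Unit = {
        try {
          handle(new IllegalStateException("boom"), needReply = false)(t => println(s"replied with: $t"))
        } catch {
          case t: Throwable => println(s"rethrown: ${t.getMessage}")
        }
      }
    }
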
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index ccf1dc5af6..687ae96204 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -85,7 +85,7 @@ private[spark] class SparkDeploySchedulerBackend(
val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt)
val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory,
command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor)
- client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf)
+ client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
client.start()
waitForRegistration()
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 014e87bb40..9cb6dd43ba 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -19,63 +19,21 @@ package org.apache.spark.deploy.master
import java.util.Date
-import scala.concurrent.Await
import scala.concurrent.duration._
import scala.io.Source
import scala.language.postfixOps
-import akka.actor.Address
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually
import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy._
class MasterSuite extends SparkFunSuite with Matchers with Eventually {
- test("toAkkaUrl") {
- val conf = new SparkConf(loadDefaults = false)
- val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp")
- assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl)
- }
-
- test("toAkkaUrl with SSL") {
- val conf = new SparkConf(loadDefaults = false)
- val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp")
- assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl)
- }
-
- test("toAkkaUrl: a typo url") {
- val conf = new SparkConf(loadDefaults = false)
- val e = intercept[SparkException] {
- Master.toAkkaUrl("spark://1.2. 3.4:1234", "akka.tcp")
- }
- assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
- }
-
- test("toAkkaAddress") {
- val conf = new SparkConf(loadDefaults = false)
- val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp")
- assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address)
- }
-
- test("toAkkaAddress with SSL") {
- val conf = new SparkConf(loadDefaults = false)
- val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp")
- assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address)
- }
-
- test("toAkkaAddress: a typo url") {
- val conf = new SparkConf(loadDefaults = false)
- val e = intercept[SparkException] {
- Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp")
- }
- assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
- }
-
test("can use a custom recovery mode factory") {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.deploy.recoveryMode", "CUSTOM")
@@ -129,16 +87,16 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
port = 10000,
cores = 0,
memory = 0,
- actor = null,
+ endpoint = null,
webUiPort = 0,
publicAddress = ""
)
- val (actorSystem, port, uiPort, restPort) =
- Master.startSystemAndActor("127.0.0.1", 7077, 8080, conf)
+ val (rpcEnv, uiPort, restPort) =
+ Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf)
try {
- Await.result(actorSystem.actorSelection("/user/Master").resolveOne(10 seconds), 10 seconds)
+ rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME)
CustomPersistenceEngine.lastInstance.isDefined shouldBe true
val persistenceEngine = CustomPersistenceEngine.lastInstance.get
@@ -154,8 +112,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
workers.map(_.id) should contain(workerToPersist.id)
} finally {
- actorSystem.shutdown()
- actorSystem.awaitTermination()
+ rpcEnv.shutdown()
+ rpcEnv.awaitTermination()
}
CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index 197f68e7ec..96e456d889 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -23,14 +23,14 @@ import javax.servlet.http.HttpServletResponse
import scala.collection.mutable
-import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import com.google.common.base.Charsets
import org.scalatest.BeforeAndAfterEach
import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods._
import org.apache.spark._
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.rpc._
+import org.apache.spark.util.Utils
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
import org.apache.spark.deploy.master.DriverState._
@@ -39,11 +39,11 @@ import org.apache.spark.deploy.master.DriverState._
* Tests for the REST application submission protocol used in standalone cluster mode.
*/
class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
- private var actorSystem: Option[ActorSystem] = None
+ private var rpcEnv: Option[RpcEnv] = None
private var server: Option[RestSubmissionServer] = None
override def afterEach() {
- actorSystem.foreach(_.shutdown())
+ rpcEnv.foreach(_.shutdown())
server.foreach(_.stop())
}
@@ -377,31 +377,32 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
killMessage: String = "driver is killed",
state: DriverState = FINISHED,
exception: Option[Exception] = None): String = {
- startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception))
+ startServer(new DummyMaster(_, submitId, submitMessage, killMessage, state, exception))
}
/** Start a smarter dummy server that keeps track of submitted driver states. */
private def startSmartServer(): String = {
- startServer(new SmarterMaster)
+ startServer(new SmarterMaster(_))
}
/** Start a dummy server that is faulty in many ways... */
private def startFaultyServer(): String = {
- startServer(new DummyMaster, faulty = true)
+ startServer(new DummyMaster(_), faulty = true)
}
/**
- * Start a [[StandaloneRestServer]] that communicates with the given actor.
+ * Start a [[StandaloneRestServer]] that communicates with the given endpoint.
* If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead.
* Return the master URL that corresponds to the address of this server.
*/
- private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = {
+ private def startServer(
+ makeFakeMaster: RpcEnv => RpcEndpoint, faulty: Boolean = false): String = {
val name = "test-standalone-rest-protocol"
val conf = new SparkConf
val localhost = Utils.localHostName()
val securityManager = new SecurityManager(conf)
- val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager)
- val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster))
+ val _rpcEnv = RpcEnv.create(name, localhost, 0, conf, securityManager)
+ val fakeMasterRef = _rpcEnv.setupEndpoint("fake-master", makeFakeMaster(_rpcEnv))
val _server =
if (faulty) {
new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077")
@@ -410,7 +411,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
}
val port = _server.start()
// set these to clean them up after every test
- actorSystem = Some(_actorSystem)
+ rpcEnv = Some(_rpcEnv)
server = Some(_server)
s"spark://$localhost:$port"
}
@@ -505,20 +506,21 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
* In all responses, the success parameter is always true.
*/
private class DummyMaster(
+ override val rpcEnv: RpcEnv,
submitId: String = "fake-driver-id",
submitMessage: String = "submitted",
killMessage: String = "killed",
state: DriverState = FINISHED,
exception: Option[Exception] = None)
- extends Actor {
+ extends RpcEndpoint {
- override def receive: PartialFunction[Any, Unit] = {
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RequestSubmitDriver(driverDesc) =>
- sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage)
+ context.reply(SubmitDriverResponse(self, success = true, Some(submitId), submitMessage))
case RequestKillDriver(driverId) =>
- sender ! KillDriverResponse(driverId, success = true, killMessage)
+ context.reply(KillDriverResponse(self, driverId, success = true, killMessage))
case RequestDriverStatus(driverId) =>
- sender ! DriverStatusResponse(found = true, Some(state), None, None, exception)
+ context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception))
}
}
@@ -531,28 +533,28 @@ private class DummyMaster(
* Submits are always successful while kills and status requests are successful only
* if the driver was submitted in the past.
*/
-private class SmarterMaster extends Actor {
+private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
private var counter: Int = 0
private val submittedDrivers = new mutable.HashMap[String, DriverState]
- override def receive: PartialFunction[Any, Unit] = {
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RequestSubmitDriver(driverDesc) =>
val driverId = s"driver-$counter"
submittedDrivers(driverId) = RUNNING
counter += 1
- sender ! SubmitDriverResponse(success = true, Some(driverId), "submitted")
+ context.reply(SubmitDriverResponse(self, success = true, Some(driverId), "submitted"))
case RequestKillDriver(driverId) =>
val success = submittedDrivers.contains(driverId)
if (success) {
submittedDrivers(driverId) = KILLED
}
- sender ! KillDriverResponse(driverId, success, "killed")
+ context.reply(KillDriverResponse(self, driverId, success, "killed"))
case RequestDriverStatus(driverId) =>
val found = submittedDrivers.contains(driverId)
val state = submittedDrivers.get(driverId)
- sender ! DriverStatusResponse(found, state, None, None, None)
+ context.reply(DriverStatusResponse(found, state, None, None, None))
}
}
@@ -568,7 +570,7 @@ private class FaultyStandaloneRestServer(
host: String,
requestedPort: Int,
masterConf: SparkConf,
- masterActor: ActorRef,
+ masterEndpoint: RpcEndpointRef,
masterUrl: String)
extends RestSubmissionServer(host, requestedPort, masterConf) {
@@ -578,7 +580,7 @@ private class FaultyStandaloneRestServer(
/** A faulty servlet that produces malformed responses. */
class MalformedSubmitServlet
- extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) {
+ extends StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) {
protected override def sendResponse(
responseMessage: SubmitRestProtocolResponse,
responseServlet: HttpServletResponse): Unit = {
@@ -588,7 +590,7 @@ private class FaultyStandaloneRestServer(
}
/** A faulty servlet that produces invalid responses. */
- class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) {
+ class InvalidKillServlet extends StandaloneKillRequestServlet(masterEndpoint, masterConf) {
protected override def handleKill(submissionId: String): KillSubmissionResponse = {
val k = super.handleKill(submissionId)
k.submissionId = null
@@ -597,7 +599,7 @@ private class FaultyStandaloneRestServer(
}
/** A faulty status servlet that explodes. */
- class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) {
+ class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterEndpoint, masterConf) {
private def explode: Int = 1 / 0
protected override def handleStatus(submissionId: String): SubmissionStatusResponse = {
val s = super.handleStatus(submissionId)
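
The test doubles above illustrate the core of this migration: where an Akka Actor replied with "sender ! response" inside receive, an RpcEndpoint answers through the RpcCallContext handed to receiveAndReply. A minimal sketch of that pattern, assuming Spark core (org.apache.spark.rpc) on the classpath; the EchoEndpoint class and the Ping/Pong messages are hypothetical and not part of this patch:

import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

private case class Ping(payload: String)
private case class Pong(payload: String)

// Answers every Ping with a Pong carrying the same payload.
private class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping(payload) =>
      // context.reply(...) takes the place of `sender ! ...` in the old Actor code.
      context.reply(Pong(payload))
  }
}
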
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
index ac18f04a11..cd24d79423 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.deploy.worker
-import akka.actor.AddressFromURIString
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.SecurityManager
import org.apache.spark.rpc.{RpcAddress, RpcEnv}
@@ -26,13 +25,11 @@ class WorkerWatcherSuite extends SparkFunSuite {
test("WorkerWatcher shuts down on valid disassociation") {
val conf = new SparkConf()
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
- val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker"
- val targetWorkerAddress = AddressFromURIString(targetWorkerUrl)
+ val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
workerWatcher.setTesting(testing = true)
rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
- workerWatcher.onDisconnected(
- RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get))
+ workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234))
assert(workerWatcher.isShutDown)
rpcEnv.shutdown()
}
@@ -40,13 +37,13 @@ class WorkerWatcherSuite extends SparkFunSuite {
test("WorkerWatcher stays alive on invalid disassociation") {
val conf = new SparkConf()
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
- val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker"
- val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor"
- val otherAkkaAddress = AddressFromURIString(otherAkkaURL)
+ val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
+ val otherAddress = "akka://test@4.3.2.1:1234/user/OtherActor"
+ val otherAkkaAddress = RpcAddress("4.3.2.1", 1234)
val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
workerWatcher.setTesting(testing = true)
rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
- workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get))
+ workerWatcher.onDisconnected(otherAkkaAddress)
assert(!workerWatcher.isShutDown)
rpcEnv.shutdown()
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala
new file mode 100644
index 0000000000..b3223ec61b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+
+class RpcAddressSuite extends SparkFunSuite {
+
+ test("hostPort") {
+ val address = RpcAddress("1.2.3.4", 1234)
+ assert(address.host == "1.2.3.4")
+ assert(address.port == 1234)
+ assert(address.hostPort == "1.2.3.4:1234")
+ }
+
+ test("fromSparkURL") {
+ val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234")
+ assert(address.host == "1.2.3.4")
+ assert(address.port == 1234)
+ }
+
+ test("fromSparkURL: a typo url") {
+ val e = intercept[SparkException] {
+ RpcAddress.fromSparkURL("spark://1.2. 3.4:1234")
+ }
+ assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
+ }
+
+ test("fromSparkURL: invalid scheme") {
+ val e = intercept[SparkException] {
+ RpcAddress.fromSparkURL("invalid://1.2.3.4:1234")
+ }
+ assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage)
+ }
+
+ test("toSparkURL") {
+ val address = RpcAddress("1.2.3.4", 1234)
+ assert(address.toSparkURL == "spark://1.2.3.4:1234")
+ }
+}
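
The new RpcAddressSuite above pins down the RpcAddress helpers that replace ad-hoc parsing of Akka addresses. A small usage sketch, assuming Spark core on the classpath (the RpcAddressExample object is hypothetical and only mirrors the calls the suite exercises):

import org.apache.spark.rpc.RpcAddress

object RpcAddressExample {
  def main(args: Array[String]): Unit = {
    // Parse a standard master URL into host and port.
    val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234")
    println(address.hostPort)   // 1.2.3.4:1234
    // Round-trip back to the spark:// form.
    println(address.toSparkURL) // spark://1.2.3.4:1234
  }
}
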
diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
index a33a83db7b..4aa75c9230 100644
--- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.rpc.akka
import org.apache.spark.rpc._
-import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf}
class AkkaRpcEnvSuite extends RpcEnvSuite {
@@ -47,4 +47,22 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
}
}
+ test("uriOf") {
+ val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint")
+ assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri)
+ }
+
+ test("uriOf: ssl") {
+ val conf = SSLSampleConfigs.sparkSSLConfig()
+ val securityManager = new SecurityManager(conf)
+ val rpcEnv = new AkkaRpcEnvFactory().create(
+ RpcEnvConfig(conf, "test", "localhost", 12346, securityManager))
+ try {
+ val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint")
+ assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri)
+ } finally {
+ rpcEnv.shutdown()
+ }
+ }
+
}
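
The uriOf tests above show why the suites no longer hard-code akka:// URLs: the RpcEnv itself renders an endpoint URI for a given system name, address, and endpoint name, and switching to SSL changes only the scheme. A minimal sketch of that call, assuming Spark core on the classpath; the UriOfExample object, system name, address, and endpoint name are placeholders, not part of this patch:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcAddress, RpcEnv}

object UriOfExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val env = RpcEnv.create("example", "localhost", 12345, conf, new SecurityManager(conf))
    try {
      // With the Akka-backed RpcEnv this yields
      // "akka.tcp://remote@1.2.3.4:1234/user/Worker", as exercised above.
      val uri = env.uriOf("remote", RpcAddress("1.2.3.4", 1234), "Worker")
      println(uri)
    } finally {
      env.shutdown()
    }
  }
}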