From f15806a8f8ca34288ddb2d74b9ff1972c8374b59 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 4 Apr 2015 11:52:05 -0700 Subject: [SPARK-6602][Core] Replace direct use of Akka with Spark RPC interface - part 1 This PR replaced the following `Actor`s to `RpcEndpoint`: 1. HeartbeatReceiver 1. ExecutorActor 1. BlockManagerMasterActor 1. BlockManagerSlaveActor 1. CoarseGrainedExecutorBackend and subclasses 1. CoarseGrainedSchedulerBackend.DriverActor This is the first PR. I will split the work of SPARK-6602 to several PRs for code review. Author: zsxwing Closes #5268 from zsxwing/rpc-rewrite and squashes the following commits: 287e9f8 [zsxwing] Fix the code style 26c56b7 [zsxwing] Merge branch 'master' into rpc-rewrite 9cc825a [zsxwing] Rmove setupThreadSafeEndpoint and add ThreadSafeRpcEndpoint 30a9036 [zsxwing] Make self return null after stopping RpcEndpointRef; fix docs and error messages 705245d [zsxwing] Fix some bugs after rebasing the changes on the master 003cf80 [zsxwing] Update CoarseGrainedExecutorBackend and CoarseGrainedSchedulerBackend to use RpcEndpoint 7d0e6dc [zsxwing] Update BlockManagerSlaveActor to use RpcEndpoint f5d6543 [zsxwing] Update BlockManagerMaster to use RpcEndpoint 30e3f9f [zsxwing] Update ExecutorActor to use RpcEndpoint 478b443 [zsxwing] Update HeartbeatReceiver to use RpcEndpoint --- .../scala/org/apache/spark/HeartbeatReceiver.scala | 66 ++- .../main/scala/org/apache/spark/SparkContext.scala | 23 +- .../src/main/scala/org/apache/spark/SparkEnv.scala | 13 +- .../executor/CoarseGrainedExecutorBackend.scala | 79 ++-- .../scala/org/apache/spark/executor/Executor.scala | 18 +- .../org/apache/spark/executor/ExecutorActor.scala | 41 -- .../apache/spark/executor/ExecutorEndpoint.scala | 43 ++ .../main/scala/org/apache/spark/rpc/RpcEnv.scala | 39 +- .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 10 +- .../org/apache/spark/scheduler/DAGScheduler.scala | 11 +- .../cluster/CoarseGrainedClusterMessage.scala | 6 +- .../cluster/CoarseGrainedSchedulerBackend.scala | 148 +++--- .../spark/scheduler/cluster/ExecutorData.scala | 8 +- .../scheduler/cluster/SimrSchedulerBackend.scala | 13 +- .../cluster/SparkDeploySchedulerBackend.scala | 14 +- .../scheduler/cluster/YarnSchedulerBackend.scala | 93 ++-- .../mesos/CoarseMesosSchedulerBackend.scala | 4 +- .../spark/scheduler/local/LocalBackend.scala | 48 +- .../org/apache/spark/storage/BlockManager.scala | 22 +- .../apache/spark/storage/BlockManagerMaster.scala | 72 ++- .../spark/storage/BlockManagerMasterActor.scala | 512 --------------------- .../spark/storage/BlockManagerMasterEndpoint.scala | 509 ++++++++++++++++++++ .../spark/storage/BlockManagerMessages.scala | 7 +- .../spark/storage/BlockManagerSlaveActor.scala | 88 ---- .../spark/storage/BlockManagerSlaveEndpoint.scala | 94 ++++ .../main/scala/org/apache/spark/util/Utils.scala | 10 + .../org/apache/spark/HeartbeatReceiverSuite.scala | 81 ++++ .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 14 +- .../storage/BlockManagerReplicationSuite.scala | 28 +- .../apache/spark/storage/BlockManagerSuite.scala | 37 +- .../streaming/ReceivedBlockHandlerSuite.scala | 25 +- .../spark/deploy/yarn/ApplicationMaster.scala | 86 ++-- .../apache/spark/deploy/yarn/YarnAllocator.scala | 2 +- 33 files changed, 1169 insertions(+), 1095 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala create mode 100644 core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 9f8ad03b91..5871b8c869 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,15 +17,15 @@ package org.apache.spark -import scala.concurrent.duration._ -import scala.collection.mutable +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} -import akka.actor.{Actor, Cancellable} +import scala.collection.mutable import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.Utils /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -51,9 +51,11 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) * Lives in the driver to receive heartbeats from executors.. */ private[spark] class HeartbeatReceiver(sc: SparkContext) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { + + override val rpcEnv: RpcEnv = sc.env.rpcEnv - private var scheduler: TaskScheduler = null + private[spark] var scheduler: TaskScheduler = null // executor ID -> timestamp of when the last heartbeat from this executor was received private val executorLastSeen = new mutable.HashMap[String, Long] @@ -69,34 +71,44 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000). getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)) - private var timeoutCheckingTask: Cancellable = null - - override def preStart(): Unit = { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts) - super.preStart() + private var timeoutCheckingTask: ScheduledFuture[_] = null + + private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("heartbeat-timeout-checking-thread")) + + private val killExecutorThread = Executors.newSingleThreadExecutor( + Utils.namedThreadFactory("kill-executor-thread")) + + override def onStart(): Unit = { + timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ExpireDeadHosts)) + } + }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - - override def receiveWithLogging: PartialFunction[Any, Unit] = { + + override def receive: PartialFunction[Any, Unit] = { + case ExpireDeadHosts => + expireDeadHosts() case TaskSchedulerIsSet => scheduler = sc.taskScheduler + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { val unknownExecutor = !scheduler.executorHeartbeatReceived( executorId, taskMetrics, blockManagerId) val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) executorLastSeen(executorId) = System.currentTimeMillis() - sender ! response + context.reply(response) } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to // register itself again. logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet") - sender ! HeartbeatResponse(reregisterBlockManager = true) + context.reply(HeartbeatResponse(reregisterBlockManager = true)) } - case ExpireDeadHosts => - expireDeadHosts() } private def expireDeadHosts(): Unit = { @@ -109,17 +121,25 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + s"timed out after ${now - lastSeenMs} ms")) if (sc.supportDynamicAllocation) { - sc.killExecutor(executorId) + // Asynchronously kill the executor to avoid blocking the current thread + killExecutorThread.submit(new Runnable { + override def run(): Unit = sc.killExecutor(executorId) + }) } executorLastSeen.remove(executorId) } } } - override def postStop(): Unit = { + override def onStop(): Unit = { if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() + timeoutCheckingTask.cancel(true) } - super.postStop() + timeoutCheckingThread.shutdownNow() + killExecutorThread.shutdownNow() } } + +object HeartbeatReceiver { + val ENDPOINT_NAME = "HeartbeatReceiver" +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3b73a8a8fd..942c5975ec 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -32,8 +32,6 @@ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} -import akka.actor.Props - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -48,12 +46,13 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.TriggerThreadDump +import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} @@ -360,14 +359,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(this)), "HeartbeatReceiver") + private val heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) // Create and start the scheduler private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - heartbeatReceiver ! TaskSchedulerIsSet + heartbeatReceiver.send(TaskSchedulerIsSet) @volatile private[spark] var dagScheduler: DAGScheduler = _ try { @@ -455,10 +454,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get + val endpointRef = env.rpcEnv.setupEndpointRef( + SparkEnv.executorActorSystemName, + RpcAddress(host, port), + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -1418,7 +1419,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli dagScheduler = null listenerBus.stop() eventLogger.foreach(_.stop()) - env.actorSystem.stop(heartbeatReceiver) + env.rpcEnv.stop(heartbeatReceiver) progressBar.foreach(_.stop()) taskScheduler = null // TODO: Cache.stop()? diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4a2ed82a40..55be0a59fe 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -295,7 +295,9 @@ object SparkEnv extends Logging { } } - def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + def registerOrLookupEndpoint( + name: String, endpointCreator: => RpcEndpoint): + RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) rpcEnv.setupEndpoint(name, endpointCreator) @@ -334,12 +336,13 @@ object SparkEnv extends Logging { new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), + conf, isDriver) // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 900e678ee0..8300f9f219 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -21,39 +21,45 @@ import java.net.URL import java.nio.ByteBuffer import scala.collection.mutable -import scala.concurrent.Await +import scala.util.{Failure, Success} -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.rpc._ +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( + override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, hostPort: String, cores: Int, userClassPath: Seq[URL], env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + @volatile var driver: Option[RpcEndpointRef] = None - override def preStart() { + override def onStart() { + import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + driver = Some(ref) + ref.sendWithReply[RegisteredExecutor.type]( + RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + } onComplete { + case Success(msg) => Utils.tryLogNonFatalError { + Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + } + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + } } def extractLogUrls: Map[String, String] = { @@ -62,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend( .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -92,23 +98,28 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - if (x.remoteAddress == driver.anchorPath.address) { - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) - } else { - logWarning(s"Received irrelevant DisassociatedEvent $x") - } - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.shutdown() + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (driver.exists(_.address == remoteAddress)) { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"An unknown ($remoteAddress) driver disconnected.") + } } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + val msg = StatusUpdate(executorId, taskId, state, data) + driver match { + case Some(driverRef) => driverRef.send(msg) + case None => logWarning(s"Drop $msg because has not yet connected to driver") + } } } @@ -132,16 +143,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( + val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val driver = fetcher.setupEndpointRefByURI(driverUrl) + val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() @@ -162,16 +171,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend actor. + // Start the CoarseGrainedExecutorBackend endpoint. val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, userClassPath, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index bf3135ef08..14f99a464b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -27,8 +27,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} @@ -88,9 +86,9 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create an RpcEndpoint for receiving RPCs from the driver + private val executorEndpoint = env.rpcEnv.setupEndpoint( + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst: Boolean = { @@ -139,7 +137,7 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.actorSystem.stop(executorActor) + env.rpcEnv.stop(executorEndpoint) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -391,11 +389,8 @@ private[spark] class Executor( } } - private val timeout = AkkaUtils.lookupTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) private val heartbeatReceiverRef = - AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { @@ -426,8 +421,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala deleted file mode 100644 index 3e47d13f75..0000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import akka.actor.Actor -import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} - -/** - * Driver -> Executor message to trigger a thread dump. - */ -private[spark] case object TriggerThreadDump - -/** - * Actor that runs inside of executors to enable driver -> executor RPC. - */ -private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { - - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case TriggerThreadDump => - sender ! Utils.getThreadDump() - } - -} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala new file mode 100644 index 0000000000..cf362f8464 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case TriggerThreadDump => + context.reply(Utils.getThreadDump()) + } + +} + +object ExecutorEndpoint { + val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 7985941d94..d47e41abcf 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -40,10 +40,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement - * [[RpcEndpoint.self]]. - * - * Note: This method won't return null. `IllegalArgumentException` will be thrown if calling this - * on a non-existent endpoint. + * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. */ private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef @@ -58,20 +55,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef - /** - * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should - * make sure thread-safely sending messages to [[RpcEndpoint]]. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[RpcEndpoint]]. In the other words, changes to internal fields of a [[RpcEndpoint]] - * are visible when processing the next message, and fields in the [[RpcEndpoint]] need not be - * volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same [[RpcEndpoint]] - * for different messages. - */ - def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef - /** * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. */ @@ -181,7 +164,7 @@ private[spark] trait RpcEnvFactory { * constructor onStart receive* onStop * * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use - * [[RpcEnv.setupThreadSafeEndpoint]] + * [[ThreadSafeRpcEndpoint]] * * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. @@ -195,7 +178,7 @@ private[spark] trait RpcEndpoint { /** * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is - * called. + * called. And `self` will become `null` when `onStop` is called. * * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. @@ -278,6 +261,19 @@ private[spark] trait RpcEndpoint { } } +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +trait ThreadSafeRpcEndpoint extends RpcEndpoint + /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ @@ -407,7 +403,8 @@ private[spark] object RpcAddress { } /** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. */ private[spark] trait RpcCallContext { diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 769d59b7b3..9e06147dff 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -82,17 +82,9 @@ private[spark] class AkkaRpcEnv private[akka] ( /** * Retrieve the [[RpcEndpointRef]] of `endpoint`. */ - override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { - val endpointRef = endpointToRef.get(endpoint) - require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}") - endpointRef - } + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - setupThreadSafeEndpoint(name, endpoint) - } - - override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @volatile var endpointRef: AkkaRpcEndpointRef = null // Use lazy because the Actor needs to use `endpointRef`. // So `actorRef` should be created after assigning `endpointRef`. diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7227fa9da4..917cce1f96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -23,14 +23,10 @@ import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal -import akka.pattern.ask -import akka.util.Timeout - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -165,11 +161,8 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(blockManagerId), 600 seconds) } // Called by TaskScheduler when an executor fails. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 9bf74f4be1..70364cea62 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -41,6 +42,7 @@ private[spark] object CoarseGrainedClusterMessages { // Executors to driver case class RegisterExecutor( executorId: String, + executorRef: RpcEndpointRef, hostPort: String, cores: Int, logUrls: Map[String, String]) @@ -70,6 +72,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage + // Exchanged between the driver and the AM in Yarn client mode case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage @@ -77,7 +81,7 @@ private[spark] object CoarseGrainedClusterMessages { // Messages exchanged between the driver and the cluster manager for executor allocation // In Yarn mode, these are exchanged between the driver and the AM - case object RegisterClusterManager extends CoarseGrainedClusterMessage + case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage // Request executors by specifying the new total number of executors desired // This includes executors already pending or running diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5d258d9da4..4c49da87af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,20 +17,16 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf - private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. @@ -71,48 +66,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends ThreadSafeRpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + private val addressToExecutorId = new HashMap[RpcAddress, String] + + private val reviveThread = + Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread")) + override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } - - def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, hostPort, cores, logUrls) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor - - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls) - // This must be synchronized because variables mutated - // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { - executorDataMap.put(executorId, data) - if (numPendingExecutors > 0) { - numPendingExecutors -= 1 - logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") - } - } - listenerBus.post( - SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) - makeOffers() + reviveThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) } + }, 0, reviveInterval, TimeUnit.MILLISECONDS) + } + override def receive: PartialFunction[Any, Unit] = { case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { @@ -133,33 +106,58 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) + if (executorDataMap.contains(executorId)) { + context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + } else { + logInfo("Registered executor: " + executorRef + " with ID " + executorId) + context.reply(RegisteredExecutor) + + addressToExecutorId(executorRef.address) = executorId + totalCoreCount.addAndGet(cores) + totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) + makeOffers() + } case StopDriver => - sender ! true - context.stop(self) + context.reply(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorEndpoint.send(StopExecutor) } - sender ! true + context.reply(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + context.reply(true) case RetrieveSparkProps => - sender ! sparkProperties + context.reply(sparkProperties) } // Make fake resource offers on all executors @@ -169,6 +167,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste }.toSeq)) } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Rpc client disassociated")) + } + // Make fake resource offers on just one executor def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) @@ -199,7 +202,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } @@ -223,9 +226,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case None => logError(s"Asked to remove non-existent executor $executorId") } } + + override def onStop() { + reviveThread.shutdownNow() + } } - var driverActor: ActorRef = null + var driverEndpoint: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -236,16 +243,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverEndpoint = rpcEnv.setupEndpoint( + CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) } def stopExecutors() { try { - if (driverActor != null) { + if (driverEndpoint != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](StopExecutors) } } catch { case e: Exception => @@ -256,22 +262,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste override def stop() { stopExecutors() try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + if (driverEndpoint != null) { + driverEndpoint.askWithReply[Boolean](StopDriver) } } catch { case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) + throw new SparkException("Error stopping standalone scheduler's driver endpoint", e) } } override def reviveOffers() { - driverActor ! ReviveOffers + driverEndpoint.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! KillTask(taskId, executorId, interruptThread) + driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -281,11 +286,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) + throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) } } @@ -391,5 +395,5 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } private[spark] object CoarseGrainedSchedulerBackend { - val ACTOR_NAME = "CoarseGrainedScheduler" + val ENDPOINT_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 5e571efe76..26e72c0bff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,20 +17,20 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The ActorRef representing this executor + * @param executorEndpoint The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorEndpoint: RpcEndpointRef, + val executorAddress: RpcAddress, override val executorHost: String, var freeCores: Int, override val totalCores: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 06786a5952..0324c9dab9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,16 +19,16 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.AkkaUtils private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") @@ -39,12 +39,9 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ffd4825705..7eb3fdc19b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,17 +19,18 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with AppClientListener with Logging { @@ -48,12 +49,9 @@ private[spark] class SparkDeploySchedulerBackend( super.start() // The endpoint for executors to talk to us - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 5a38ad9f2b..f72566c370 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{Future, ExecutionContext} -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils @@ -37,7 +35,7 @@ import scala.util.control.NonFatal private[spark] abstract class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 @@ -45,10 +43,8 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) @@ -57,16 +53,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -96,64 +90,71 @@ private[spark] abstract class YarnSchedulerBackend( } /** - * An actor that communicates with the ApplicationMaster. + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None - - implicit val askAmActorExecutor = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private val askAmThreadPool = + Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) override def receive: PartialFunction[Any, Unit] = { - case RegisterClusterManager => - logInfo(s"ApplicationMaster registered as $sender") - amActor = Some(sender) + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Some(am) + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + context.reply(am.askWithReply[Boolean](r)) } onFailure { - case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + context.reply(false) } case k: KillExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + context.reply(am.askWithReply[Boolean](k)) } onFailure { - case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + context.reply(false) } - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! true + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } + } + + override def onStop(): Unit ={ + askAmThreadPool.shutdownNow() } } } private[spark] object YarnSchedulerBackend { - val ACTOR_NAME = "YarnScheduler" + val ENDPOINT_NAME = "YarnScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e13de0f46e..b037a4966c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with Logging { @@ -148,7 +148,7 @@ private[spark] class CoarseMesosSchedulerBackend( SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.get("spark.executor.uri", null) if (uri == null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index eb3f999b5b..70a477a689 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,17 +18,14 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer +import java.util.concurrent.{Executors, TimeUnit} -import scala.concurrent.duration._ -import scala.language.postfixOps - -import akka.actor.{Actor, ActorRef, Props} - +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -39,17 +36,19 @@ private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() /** - * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on - * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the + * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend * and the TaskSchedulerImpl. */ -private[spark] class LocalActor( +private[spark] class LocalEndpoint( + override val rpcEnv: RpcEnv, scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { - import context.dispatcher // to use Akka's scheduler.scheduleOnce() + private val reviveThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("local-revive-thread")) private var freeCores = totalCores @@ -59,7 +58,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -87,9 +86,17 @@ private[spark] class LocalActor( } if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout - context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers) + reviveThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) + } + }, 1000, TimeUnit.MILLISECONDS) } } + + override def onStop(): Unit = { + reviveThread.shutdownNow() + } } /** @@ -101,31 +108,30 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: extends SchedulerBackend with ExecutorBackend { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + var localEndpoint: RpcEndpointRef = null override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( + "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) } override def stop() { - localActor ! StopExecutor + localEndpoint.send(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localEndpoint.send(ReviveOffers) } override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localEndpoint.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! StatusUpdate(taskId, state, serializedData) + localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fc31296f4d..1aa0ef18de 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ @@ -37,6 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -64,7 +64,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -136,9 +136,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveEndpoint = rpcEnv.setupEndpoint( + "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. @@ -167,7 +167,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -176,7 +176,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -186,7 +186,7 @@ private[spark] class BlockManager( * where it is only learned after registration with the TaskScheduler). * * This method initializes the BlockTransferService and ShuffleClient, registers with the - * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle * service if configured. */ def initialize(appId: String): Unit = { @@ -202,7 +202,7 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -265,7 +265,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) reportAllBlocks() } @@ -1215,7 +1215,7 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveEndpoint) blockInfo.clear() memoryStore.clear() diskStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 061964826f..ceacf04302 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -20,35 +20,31 @@ package org.apache.spark.storage import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor._ - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverEndpoint: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" val timeout = AkkaUtils.askTimeout(conf) - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ + /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { tell(RemoveExecutor(execId)) logInfo("Removed " + execId + " successfully in removeExecutor") } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) logInfo("Registered BlockManager") } @@ -59,7 +55,7 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( + val res = driverEndpoint.askWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logDebug(s"Updated info of block $blockId") res @@ -67,12 +63,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** @@ -85,11 +81,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -97,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - askDriverWithReply(RemoveBlock(blockId)) + driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -114,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -126,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -145,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) } /** @@ -165,11 +161,12 @@ class BlockManagerMaster( askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) /* - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. + * master endpoint for a response to a prior message. */ - val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val response = driverEndpoint. + askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -193,33 +190,28 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } - /** Stop the driver actor, called only on the Spark driver node */ + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { - if (driverActor != null && isDriver) { + if (driverEndpoint != null && isDriver) { tell(StopBlockManagerMaster) - driverActor = null + driverEndpoint = null logInfo("BlockManagerMaster stopped") } } - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") + if (!driverEndpoint.askWithReply[Boolean](message)) { + throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) - } +} +private[spark] object BlockManagerMaster { + val DRIVER_ENDPOINT_NAME = "BlockManagerMaster" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala deleted file mode 100644 index 5b53280161..0000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ /dev/null @@ -1,512 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable -import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ - -import akka.actor.{Actor, ActorRef} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} - -/** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. - */ -private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { - - // Mapping from block manager id to the block manager's information. - private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] - - // Mapping from executor ID to block manager ID. - private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] - - // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - - private val akkaTimeout = AkkaUtils.askTimeout(conf) - - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true - - case UpdateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) - - case GetLocations(blockId) => - sender ! getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) - - case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) - - case GetMemoryStatus => - sender ! memoryStatus - - case GetStorageStatus => - sender ! storageStatus - - case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) - - case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) - - case RemoveRdd(rddId) => - sender ! removeRdd(rddId) - - case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) - - case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) - - case RemoveBlock(blockId) => - removeBlockFromWorkers(blockId) - sender ! true - - case RemoveExecutor(execId) => - removeExecutor(execId) - sender ! true - - case StopBlockManagerMaster => - sender ! true - context.stop(self) - - case BlockManagerHeartbeat(blockManagerId) => - sender ! heartbeatReceived(blockManagerId) - - case other => - logWarning("Got unknown message: " + other) - } - - private def removeRdd(rddId: Int): Future[Seq[Int]] = { - // First remove the metadata for the given RDD, and then asynchronously remove the blocks - // from the slaves. - - // Find all blocks for the given RDD, remove the block from both blockLocations and - // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocks.foreach { blockId => - val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) - bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) - blockLocations.remove(blockId) - } - - // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. - // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher - val removeMsg = RemoveRdd(rddId) - Future.sequence( - blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq - ) - } - - private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher - val removeMsg = RemoveShuffle(shuffleId) - Future.sequence( - blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] - }.toSeq - ) - } - - /** - * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified - * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed - * from the executors, but not from the driver. - */ - private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher - val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) - val requiredBlockManagers = blockManagerInfo.values.filter { info => - removeFromDriver || !info.blockManagerId.isDriver - } - Future.sequence( - requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq - ) - } - - private def removeBlockManager(blockManagerId: BlockManagerId) { - val info = blockManagerInfo(blockManagerId) - - // Remove the block manager from blockManagerIdByExecutor. - blockManagerIdByExecutor -= blockManagerId.executorId - - // Remove it from blockManagerInfo and remove all the blocks. - blockManagerInfo.remove(blockManagerId) - val iterator = info.blocks.keySet.iterator - while (iterator.hasNext) { - val blockId = iterator.next - val locations = blockLocations.get(blockId) - locations -= blockManagerId - if (locations.size == 0) { - blockLocations.remove(blockId) - } - } - listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) - logInfo(s"Removing block manager $blockManagerId") - } - - private def removeExecutor(execId: String) { - logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") - blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) - } - - /** - * Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. - */ - private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { - if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerId.isDriver && !isLocal - } else { - blockManagerInfo(blockManagerId).updateLastSeenMs() - true - } - } - - // Remove a block from the slaves that have it. This can only be used to remove - // blocks that the master knows about. - private def removeBlockFromWorkers(blockId: BlockId) { - val locations = blockLocations.get(blockId) - if (locations != null) { - locations.foreach { blockManagerId: BlockManagerId => - val blockManager = blockManagerInfo.get(blockManagerId) - if (blockManager.isDefined) { - // Remove the block from the slave's BlockManager. - // Doesn't actually wait for a confirmation and the message might get lost. - // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) - } - } - } - } - - // Return a map from the block manager id to max memory and remaining memory. - private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { - blockManagerInfo.map { case(blockManagerId, info) => - (blockManagerId, (info.maxMem, info.remainingMem)) - }.toMap - } - - private def storageStatus: Array[StorageStatus] = { - blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks) - }.toArray - } - - /** - * Return the block's status for all block managers, if any. NOTE: This is a - * potentially expensive operation and should only be used for testing. - * - * If askSlaves is true, the master queries each block manager for the most updated block - * statuses. This is useful when the master is not informed of the given block by all block - * managers. - */ - private def blockStatus( - blockId: BlockId, - askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher - val getBlockStatus = GetBlockStatus(blockId) - /* - * Rather than blocking on the block status query, master actor should simply return - * Futures to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. - */ - blockManagerInfo.values.map { info => - val blockStatusFuture = - if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] - } else { - Future { info.getStatus(blockId) } - } - (info.blockManagerId, blockStatusFuture) - }.toMap - } - - /** - * Return the ids of blocks present in all the block managers that match the given filter. - * NOTE: This is a potentially expensive operation and should only be used for testing. - * - * If askSlaves is true, the master queries each block manager for the most updated block - * statuses. This is useful when the master is not informed of the given block by all block - * managers. - */ - private def getMatchingBlockIds( - filter: BlockId => Boolean, - askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher - val getMatchingBlockIds = GetMatchingBlockIds(filter) - Future.sequence( - blockManagerInfo.values.map { info => - val future = - if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] - } else { - Future { info.blocks.keys.filter(filter).toSeq } - } - future - } - ).map(_.flatten.toSeq) - } - - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - val time = System.currentTimeMillis() - if (!blockManagerInfo.contains(id)) { - blockManagerIdByExecutor.get(id.executorId) match { - case Some(oldId) => - // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " - + s" will replace old one $oldId with new one $id") - removeExecutor(id.executorId) - case None => - } - logInfo("Registering block manager %s with %s RAM, %s".format( - id.hostPort, Utils.bytesToString(maxMemSize), id)) - - blockManagerIdByExecutor(id.executorId) = id - - blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) - } - listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) - } - - private def updateBlockInfo( - blockManagerId: BlockManagerId, - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long): Boolean = { - - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.isDriver && !isLocal) { - // We intentionally do not register the master (except in local mode), - // so we should not indicate failure. - return true - } else { - return false - } - } - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - return true - } - - blockManagerInfo(blockManagerId).updateBlockInfo( - blockId, storageLevel, memSize, diskSize, tachyonSize) - - var locations: mutable.HashSet[BlockManagerId] = null - if (blockLocations.containsKey(blockId)) { - locations = blockLocations.get(blockId) - } else { - locations = new mutable.HashSet[BlockManagerId] - blockLocations.put(blockId, locations) - } - - if (storageLevel.isValid) { - locations.add(blockManagerId) - } else { - locations.remove(blockManagerId) - } - - // Remove the block from master tracking if it has been removed on all slaves. - if (locations.size == 0) { - blockLocations.remove(blockId) - } - true - } - - private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty - } - - private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - blockIds.map(blockId => getLocations(blockId)) - } - - /** Get the list of the peers of the given block manager */ - private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - val blockManagerIds = blockManagerInfo.keySet - if (blockManagerIds.contains(blockManagerId)) { - blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq - } else { - Seq.empty - } - } - - /** - * Returns the hostname and port of an executor's actor system, based on the Akka address of its - * BlockManagerSlaveActor. - */ - private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - for ( - blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port - ) yield { - (host, port) - } - } -} - -@DeveloperApi -case class BlockStatus( - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long) { - def isCached: Boolean = memSize + diskSize + tachyonSize > 0 -} - -@DeveloperApi -object BlockStatus { - def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) -} - -private[spark] class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long, - val slaveActor: ActorRef) - extends Logging { - - private var _lastSeenMs: Long = timeMs - private var _remainingMem: Long = maxMem - - // Mapping from block id to its status. - private val _blocks = new JHashMap[BlockId, BlockStatus] - - def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) - - def updateLastSeenMs() { - _lastSeenMs = System.currentTimeMillis() - } - - def updateBlockInfo( - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long) { - - updateLastSeenMs() - - if (_blocks.containsKey(blockId)) { - // The block exists on the slave already. - val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize - - if (originalLevel.useMemory) { - _remainingMem += originalMemSize - } - } - - if (storageLevel.isValid) { - /* isValid means it is either stored in-memory, on-disk or on-Tachyon. - * The memSize here indicates the data size in or dropped from memory, - * tachyonSize here indicates the data size in or dropped from Tachyon, - * and the diskSize here indicates the data size in or dropped to disk. - * They can be both larger than 0, when a block is dropped from memory to disk. - * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ - if (storageLevel.useMemory) { - _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) - _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) - } - if (storageLevel.useDisk) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) - } - if (storageLevel.useOffHeap) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize)) - logInfo("Added %s on tachyon on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize))) - } - } else if (_blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) - _blocks.remove(blockId) - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) - } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) - } - if (blockStatus.storageLevel.useOffHeap) { - logInfo("Removed %s on %s on tachyon (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize))) - } - } - } - - def removeBlock(blockId: BlockId) { - if (_blocks.containsKey(blockId)) { - _remainingMem += _blocks.get(blockId).memSize - _blocks.remove(blockId) - } - } - - def remainingMem: Long = _remainingMem - - def lastSeenMs: Long = _lastSeenMs - - def blocks: JHashMap[BlockId, BlockStatus] = _blocks - - override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem - - def clear() { - _blocks.clear() - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala new file mode 100644 index 0000000000..28c73a7d54 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.{HashMap => JHashMap} + +import scala.collection.mutable +import scala.collection.JavaConversions._ +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.Utils + +/** + * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * of all slaves' block managers. + */ +private[spark] +class BlockManagerMasterEndpoint( + override val rpcEnv: RpcEnv, + val isLocal: Boolean, + conf: SparkConf, + listenerBus: LiveListenerBus) + extends ThreadSafeRpcEndpoint with Logging { + + // Mapping from block manager id to the block manager's information. + private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] + + // Mapping from executor ID to block manager ID. + private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] + + // Mapping from block id to the set of block managers that have the block. + private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] + + private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => + register(blockManagerId, maxMemSize, slaveEndpoint) + context.reply(true) + + case UpdateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => + context.reply(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) + + case GetLocations(blockId) => + context.reply(getLocations(blockId)) + + case GetLocationsMultipleBlockIds(blockIds) => + context.reply(getLocationsMultipleBlockIds(blockIds)) + + case GetPeers(blockManagerId) => + context.reply(getPeers(blockManagerId)) + + case GetRpcHostPortForExecutor(executorId) => + context.reply(getRpcHostPortForExecutor(executorId)) + + case GetMemoryStatus => + context.reply(memoryStatus) + + case GetStorageStatus => + context.reply(storageStatus) + + case GetBlockStatus(blockId, askSlaves) => + context.reply(blockStatus(blockId, askSlaves)) + + case GetMatchingBlockIds(filter, askSlaves) => + context.reply(getMatchingBlockIds(filter, askSlaves)) + + case RemoveRdd(rddId) => + context.reply(removeRdd(rddId)) + + case RemoveShuffle(shuffleId) => + context.reply(removeShuffle(shuffleId)) + + case RemoveBroadcast(broadcastId, removeFromDriver) => + context.reply(removeBroadcast(broadcastId, removeFromDriver)) + + case RemoveBlock(blockId) => + removeBlockFromWorkers(blockId) + context.reply(true) + + case RemoveExecutor(execId) => + removeExecutor(execId) + context.reply(true) + + case StopBlockManagerMaster => + context.reply(true) + stop() + + case BlockManagerHeartbeat(blockManagerId) => + context.reply(heartbeatReceived(blockManagerId)) + + } + + private def removeRdd(rddId: Int): Future[Seq[Int]] = { + // First remove the metadata for the given RDD, and then asynchronously remove the blocks + // from the slaves. + + // Find all blocks for the given RDD, remove the block from both blockLocations and + // the blockManagerInfo that is tracking the blocks. + val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + blocks.foreach { blockId => + val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) + bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) + blockLocations.remove(blockId) + } + + // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. + // The dispatcher is used as an implicit argument into the Future sequence construction. + val removeMsg = RemoveRdd(rddId) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveEndpoint.sendWithReply[Int](removeMsg) + }.toSeq + ) + } + + private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { + // Nothing to do in the BlockManagerMasterEndpoint data structures + val removeMsg = RemoveShuffle(shuffleId) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) + }.toSeq + ) + } + + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. + */ + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { + val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) + val requiredBlockManagers = blockManagerInfo.values.filter { info => + removeFromDriver || !info.blockManagerId.isDriver + } + Future.sequence( + requiredBlockManagers.map { bm => + bm.slaveEndpoint.sendWithReply[Int](removeMsg) + }.toSeq + ) + } + + private def removeBlockManager(blockManagerId: BlockManagerId) { + val info = blockManagerInfo(blockManagerId) + + // Remove the block manager from blockManagerIdByExecutor. + blockManagerIdByExecutor -= blockManagerId.executorId + + // Remove it from blockManagerInfo and remove all the blocks. + blockManagerInfo.remove(blockManagerId) + val iterator = info.blocks.keySet.iterator + while (iterator.hasNext) { + val blockId = iterator.next + val locations = blockLocations.get(blockId) + locations -= blockManagerId + if (locations.size == 0) { + blockLocations.remove(blockId) + } + } + listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) + logInfo(s"Removing block manager $blockManagerId") + } + + private def removeExecutor(execId: String) { + logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") + blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) + } + + /** + * Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { + if (!blockManagerInfo.contains(blockManagerId)) { + blockManagerId.isDriver && !isLocal + } else { + blockManagerInfo(blockManagerId).updateLastSeenMs() + true + } + } + + // Remove a block from the slaves that have it. This can only be used to remove + // blocks that the master knows about. + private def removeBlockFromWorkers(blockId: BlockId) { + val locations = blockLocations.get(blockId) + if (locations != null) { + locations.foreach { blockManagerId: BlockManagerId => + val blockManager = blockManagerInfo.get(blockManagerId) + if (blockManager.isDefined) { + // Remove the block from the slave's BlockManager. + // Doesn't actually wait for a confirmation and the message might get lost. + // If message loss becomes frequent, we should add retry logic here. + blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) + } + } + } + } + + // Return a map from the block manager id to max memory and remaining memory. + private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { + blockManagerInfo.map { case(blockManagerId, info) => + (blockManagerId, (info.maxMem, info.remainingMem)) + }.toMap + } + + private def storageStatus: Array[StorageStatus] = { + blockManagerInfo.map { case (blockManagerId, info) => + new StorageStatus(blockManagerId, info.maxMem, info.blocks) + }.toArray + } + + /** + * Return the block's status for all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def blockStatus( + blockId: BlockId, + askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { + val getBlockStatus = GetBlockStatus(blockId) + /* + * Rather than blocking on the block status query, master endpoint should simply return + * Futures to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master endpoint's response to a previous message. + */ + blockManagerInfo.values.map { info => + val blockStatusFuture = + if (askSlaves) { + info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) + } else { + Future { info.getStatus(blockId) } + } + (info.blockManagerId, blockStatusFuture) + }.toMap + } + + /** + * Return the ids of blocks present in all the block managers that match the given filter. + * NOTE: This is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def getMatchingBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Future[Seq[BlockId]] = { + val getMatchingBlockIds = GetMatchingBlockIds(filter) + Future.sequence( + blockManagerInfo.values.map { info => + val future = + if (askSlaves) { + info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) + } else { + Future { info.blocks.keys.filter(filter).toSeq } + } + future + } + ).map(_.flatten.toSeq) + } + + private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { + val time = System.currentTimeMillis() + if (!blockManagerInfo.contains(id)) { + blockManagerIdByExecutor.get(id.executorId) match { + case Some(oldId) => + // A block manager of the same executor already exists, so remove it (assumed dead) + logError("Got two different block manager registrations on same executor - " + + s" will replace old one $oldId with new one $id") + removeExecutor(id.executorId) + case None => + } + logInfo("Registering block manager %s with %s RAM, %s".format( + id.hostPort, Utils.bytesToString(maxMemSize), id)) + + blockManagerIdByExecutor(id.executorId) = id + + blockManagerInfo(id) = new BlockManagerInfo( + id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) + } + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + } + + private def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + tachyonSize: Long): Boolean = { + + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.isDriver && !isLocal) { + // We intentionally do not register the master (except in local mode), + // so we should not indicate failure. + return true + } else { + return false + } + } + + if (blockId == null) { + blockManagerInfo(blockManagerId).updateLastSeenMs() + return true + } + + blockManagerInfo(blockManagerId).updateBlockInfo( + blockId, storageLevel, memSize, diskSize, tachyonSize) + + var locations: mutable.HashSet[BlockManagerId] = null + if (blockLocations.containsKey(blockId)) { + locations = blockLocations.get(blockId) + } else { + locations = new mutable.HashSet[BlockManagerId] + blockLocations.put(blockId, locations) + } + + if (storageLevel.isValid) { + locations.add(blockManagerId) + } else { + locations.remove(blockManagerId) + } + + // Remove the block from master tracking if it has been removed on all slaves. + if (locations.size == 0) { + blockLocations.remove(blockId) + } + true + } + + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { + if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty + } + + private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + blockIds.map(blockId => getLocations(blockId)) + } + + /** Get the list of the peers of the given block manager */ + private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { + val blockManagerIds = blockManagerInfo.keySet + if (blockManagerIds.contains(blockManagerId)) { + blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq + } else { + Seq.empty + } + } + + /** + * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its + * [[BlockManagerSlaveEndpoint]]. + */ + private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId) + ) yield { + (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) + } + } + + override def onStop(): Unit = { + askThreadPool.shutdownNow() + } +} + +@DeveloperApi +case class BlockStatus( + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + tachyonSize: Long) { + def isCached: Boolean = memSize + diskSize + tachyonSize > 0 +} + +@DeveloperApi +object BlockStatus { + def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) +} + +private[spark] class BlockManagerInfo( + val blockManagerId: BlockManagerId, + timeMs: Long, + val maxMem: Long, + val slaveEndpoint: RpcEndpointRef) + extends Logging { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[BlockId, BlockStatus] + + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) + + def updateLastSeenMs() { + _lastSeenMs = System.currentTimeMillis() + } + + def updateBlockInfo( + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + tachyonSize: Long) { + + updateLastSeenMs() + + if (_blocks.containsKey(blockId)) { + // The block exists on the slave already. + val blockStatus: BlockStatus = _blocks.get(blockId) + val originalLevel: StorageLevel = blockStatus.storageLevel + val originalMemSize: Long = blockStatus.memSize + + if (originalLevel.useMemory) { + _remainingMem += originalMemSize + } + } + + if (storageLevel.isValid) { + /* isValid means it is either stored in-memory, on-disk or on-Tachyon. + * The memSize here indicates the data size in or dropped from memory, + * tachyonSize here indicates the data size in or dropped from Tachyon, + * and the diskSize here indicates the data size in or dropped to disk. + * They can be both larger than 0, when a block is dropped from memory to disk. + * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ + if (storageLevel.useMemory) { + _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) + _remainingMem -= memSize + logInfo("Added %s in memory on %s (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), + Utils.bytesToString(_remainingMem))) + } + if (storageLevel.useDisk) { + _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) + logInfo("Added %s on disk on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + } + if (storageLevel.useOffHeap) { + _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize)) + logInfo("Added %s on tachyon on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize))) + } + } else if (_blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. + val blockStatus: BlockStatus = _blocks.get(blockId) + _blocks.remove(blockId) + if (blockStatus.storageLevel.useMemory) { + logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), + Utils.bytesToString(_remainingMem))) + } + if (blockStatus.storageLevel.useDisk) { + logInfo("Removed %s on %s on disk (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + } + if (blockStatus.storageLevel.useOffHeap) { + logInfo("Removed %s on %s on tachyon (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize))) + } + } + } + + def removeBlock(blockId: BlockId) { + if (_blocks.containsKey(blockId)) { + _remainingMem += _blocks.get(blockId).memSize + _blocks.remove(blockId) + } + } + + def remainingMem: Long = _remainingMem + + def lastSeenMs: Long = _lastSeenMs + + def blocks: JHashMap[BlockId, BlockStatus] = _blocks + + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem + + def clear() { + _blocks.clear() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 48247453ed..f89d8d7493 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala deleted file mode 100644 index 52fb896c4e..0000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import scala.concurrent.Future - -import akka.actor.{ActorRef, Actor} - -import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} -import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive - -/** - * An actor to take commands from the master to execute options. For example, - * this is used to remove blocks from the slave's BlockManager. - */ -private[storage] -class BlockManagerSlaveActor( - blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { - - import context.dispatcher - - // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RemoveBlock(blockId) => - doAsync[Boolean]("removing block " + blockId, sender) { - blockManager.removeBlock(blockId) - true - } - - case RemoveRdd(rddId) => - doAsync[Int]("removing RDD " + rddId, sender) { - blockManager.removeRdd(rddId) - } - - case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle " + shuffleId, sender) { - if (mapOutputTracker != null) { - mapOutputTracker.unregisterShuffle(shuffleId) - } - SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) - } - - case RemoveBroadcast(broadcastId, _) => - doAsync[Int]("removing broadcast " + broadcastId, sender) { - blockManager.removeBroadcast(broadcastId, tellMaster = true) - } - - case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) - - case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) - } - - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { - val future = Future { - logDebug(actionMessage) - body - } - future.onSuccess { case response => - logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response - logDebug("Sent response: " + response + " to " + responseActor) - } - future.onFailure { case t: Throwable => - logError("Error in " + actionMessage, t) - responseActor ! null.asInstanceOf[T] - } - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala new file mode 100644 index 0000000000..8980fa8eb7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.storage.BlockManagerMessages._ + +/** + * An RpcEndpoint to take commands from the master to execute options. For example, + * this is used to remove blocks from the slave's BlockManager. + */ +private[storage] +class BlockManagerSlaveEndpoint( + override val rpcEnv: RpcEnv, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + extends RpcEndpoint with Logging { + + private val asyncThreadPool = + Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) + + // Operations that involve removing blocks may be slow and should be done asynchronously + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RemoveBlock(blockId) => + doAsync[Boolean]("removing block " + blockId, context) { + blockManager.removeBlock(blockId) + true + } + + case RemoveRdd(rddId) => + doAsync[Int]("removing RDD " + rddId, context) { + blockManager.removeRdd(rddId) + } + + case RemoveShuffle(shuffleId) => + doAsync[Boolean]("removing shuffle " + shuffleId, context) { + if (mapOutputTracker != null) { + mapOutputTracker.unregisterShuffle(shuffleId) + } + SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) + } + + case RemoveBroadcast(broadcastId, _) => + doAsync[Int]("removing broadcast " + broadcastId, context) { + blockManager.removeBroadcast(broadcastId, tellMaster = true) + } + + case GetBlockStatus(blockId, _) => + context.reply(blockManager.getStatus(blockId)) + + case GetMatchingBlockIds(filter, _) => + context.reply(blockManager.getMatchingBlockIds(filter)) + } + + private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { + val future = Future { + logDebug(actionMessage) + body + } + future.onSuccess { case response => + logDebug("Done " + actionMessage + ", response is " + response) + context.reply(response) + logDebug("Sent response: " + response + " to " + context.sender) + } + future.onFailure { case t: Throwable => + logError("Error in " + actionMessage, t) + context.sendFailure(t) + } + } + + override def onStop(): Unit = { + asyncThreadPool.shutdownNow() + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7c85e28679..0fdfaf300e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1214,6 +1214,16 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */ + def tryLogNonFatalError(block: => Unit) { + try { + block + } catch { + case NonFatal(t) => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } + /** * Execute a block of code, then a finally block, but if exceptions happen in * the finally block, do not suppress the original exception. diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 0000000000..0fd570e529 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito.{mock, spy, verify, when} +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.RpcUtils +import org.scalatest.concurrent.Eventually._ + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index e07bdb9637..4f19c4f211 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -311,7 +311,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } test("self: call in onStop") { - @volatile var e: Throwable = null + @volatile var selfOption: Option[RpcEndpointRef] = null val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { override val rpcEnv = env @@ -321,20 +321,18 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } override def onStop(): Unit = { - self + selfOption = Option(self) } override def onError(cause: Throwable): Unit = { - e = cause } }) env.stop(endpointRef) eventually(timeout(5 seconds), interval(10 millis)) { - // Calling `self` in `onStop` is invalid - assert(e != null) - assert(e.getMessage.contains("Cannot find RpcEndpointRef")) + // Calling `self` in `onStop` will return null, so selfOption will be None + assert(selfOption == None) } } @@ -342,7 +340,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it for(i <- 0 until 100) { @volatile var result = 0 - val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint { + val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive = { @@ -475,7 +473,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { test("network events") { val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupThreadSafeEndpoint("network-events", new RpcEndpoint { + env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive = { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c8597..b4de90b65d 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,11 +22,11 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService @@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -69,12 +68,10 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -83,18 +80,17 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } after { allStores.foreach { _.stop() } allStores.clear() - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd1cba5b5..283090e3bd 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,24 +19,18 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ -import akka.pattern.ask -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} - import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService @@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -72,28 +66,25 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -108,9 +99,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2.stop() store2 = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -357,10 +348,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - implicit val timeout = Timeout(30, TimeUnit.SECONDS) - val reregister = !Await.result( - master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), - timeout.duration).asInstanceOf[Boolean] + val reregister = !master.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -785,7 +774,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 18a477f920..ef4873de2f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,20 +24,20 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.{AkkaUtils, ManualClock, Utils} +import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ @@ -54,22 +54,19 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val manualClock = new ManualClock val blockManagerSize = 10000000 - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null var tempDirectory: File = null before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) + conf.set("spark.driver.port", rpcEnv.address.port.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") @@ -87,9 +84,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null Utils.deleteRecursively(tempDirectory) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 455554eea0..24a1e02795 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -24,22 +24,20 @@ import java.lang.reflect.InvocationTargetException import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference -import akka.actor._ -import akka.remote._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.YarnSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, - SignalLogger, Utils} +import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. @@ -72,8 +70,8 @@ private[spark] class ApplicationMaster( @volatile private var allocator: YarnAllocator = _ // Fields used in client mode. - private var actorSystem: ActorSystem = null - private var actor: ActorRef = _ + private var rpcEnv: RpcEnv = null + private var amEndpoint: RpcEndpointRef = _ // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) @@ -240,22 +238,21 @@ private[spark] class ApplicationMaster( } /** - * Create an actor that communicates with the driver. + * Create an [[RpcEndpoint]] that communicates with the driver. * * In cluster mode, the AM and the driver belong to same process - * so the AM actor need not monitor lifecycle of the driver. + * so the AMEndpoint need not monitor lifecycle of the driver. */ - private def runAMActor( + private def runAMEndpoint( host: String, port: String, isClusterMode: Boolean): Unit = { - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), + val driverEndpont = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, - host, - port, - YarnSchedulerBackend.ACTOR_NAME) - actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isClusterMode)), name = "YarnAM") + RpcAddress(host, port.toInt), + YarnSchedulerBackend.ENDPOINT_NAME) + amEndpoint = + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode)) } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -272,8 +269,8 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { - actorSystem = sc.env.actorSystem - runAMActor( + rpcEnv = sc.env.rpcEnv + runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) @@ -283,8 +280,7 @@ private[spark] class ApplicationMaster( } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityMgr)._1 + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, 0, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -431,7 +427,7 @@ private[spark] class ApplicationMaster( sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - runAMActor(driverHost, driverPort.toString, isClusterMode = false) + runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -443,7 +439,7 @@ private[spark] class ApplicationMaster( System.setProperty("spark.ui.filters", amFilter) params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase) + amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase)) } } @@ -505,44 +501,29 @@ private[spark] class ApplicationMaster( } /** - * An actor that communicates with the driver's scheduler backend. + * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor { - var driver: ActorSelection = _ - - override def preStart(): Unit = { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message to establish the connection, after which - // we can monitor Lifecycle Events. - driver ! "Hello" - driver ! RegisterClusterManager - // In cluster mode, the AM can directly monitor the driver status instead - // of trying to deduce it from the lifecycle of the driver's actor - if (!isClusterMode) { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private class AMEndpoint( + override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) + extends RpcEndpoint with Logging { + + override def onStart(): Unit = { + driver.send(RegisterClusterManager(self)) } override def receive: PartialFunction[Any, Unit] = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. $x") - // In cluster mode, do not rely on the disassociated event to exit - // This avoids potentially reporting incorrect exit codes if the driver fails - if (!isClusterMode) { - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - } - case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") - driver ! x + driver.send(x) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { case Some(a) => a.requestTotalExecutors(requestedTotal) case None => logWarning("Container allocator is not ready to request executors yet.") } - sender ! true + context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -550,7 +531,16 @@ private[spark] class ApplicationMaster( case Some(a) => executorIds.foreach(a.killExecutor) case None => logWarning("Container allocator is not ready to kill executors yet.") } - sender ! true + context.reply(true) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isClusterMode) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index c98763e15b..b8f42dadcb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -112,7 +112,7 @@ private[yarn] class YarnAllocator( SparkEnv.driverActorSystemName, sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) -- cgit v1.2.3