author     zsxwing <zsxwing@gmail.com>          2015-04-04 11:52:05 -0700
committer  Reynold Xin <rxin@databricks.com>    2015-04-04 11:52:05 -0700
commit     f15806a8f8ca34288ddb2d74b9ff1972c8374b59 (patch)
tree       88abe5de9fadf078e57951450cb3368d0fb7cb64 /core
parent     7bca62f79056e592cf07b49d8b8d04c59dea25fc (diff)
[SPARK-6602][Core] Replace direct use of Akka with Spark RPC interface - part 1
This PR replaces the following `Actor`s with `RpcEndpoint`s:

1. HeartbeatReceiver
2. ExecutorActor
3. BlockManagerMasterActor
4. BlockManagerSlaveActor
5. CoarseGrainedExecutorBackend and subclasses
6. CoarseGrainedSchedulerBackend.DriverActor

This is the first PR. I will split the work of SPARK-6602 into several PRs for code review.

Author: zsxwing <zsxwing@gmail.com>

Closes #5268 from zsxwing/rpc-rewrite and squashes the following commits:

287e9f8 [zsxwing] Fix the code style
26c56b7 [zsxwing] Merge branch 'master' into rpc-rewrite
9cc825a [zsxwing] Rmove setupThreadSafeEndpoint and add ThreadSafeRpcEndpoint
30a9036 [zsxwing] Make self return null after stopping RpcEndpointRef; fix docs and error messages
705245d [zsxwing] Fix some bugs after rebasing the changes on the master
003cf80 [zsxwing] Update CoarseGrainedExecutorBackend and CoarseGrainedSchedulerBackend to use RpcEndpoint
7d0e6dc [zsxwing] Update BlockManagerSlaveActor to use RpcEndpoint
f5d6543 [zsxwing] Update BlockManagerMaster to use RpcEndpoint
30e3f9f [zsxwing] Update ExecutorActor to use RpcEndpoint
478b443 [zsxwing] Update HeartbeatReceiver to use RpcEndpoint
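To make the migration pattern concrete, here is a minimal, self-contained Scala sketch (not part of the patch) of the endpoint shape this series moves to: one-way messages go through `receive`, request/response messages go through `receiveAndReply` and answer via the call context instead of Akka's `sender`. The `Sketch*` names are simplified stand-ins for Spark's private `ThreadSafeRpcEndpoint` and `RpcCallContext` interfaces, not the real classes.

// Simplified stand-ins for the private org.apache.spark.rpc interfaces.
trait SketchRpcCallContext {
  def reply(response: Any): Unit
}

trait SketchThreadSafeRpcEndpoint {
  def onStart(): Unit = {}
  def onStop(): Unit = {}
  // One-way messages (previously handled in Actor.receive via `!`).
  def receive: PartialFunction[Any, Unit] = PartialFunction.empty
  // Request/response messages; replies go through the context, not `sender`.
  def receiveAndReply(context: SketchRpcCallContext): PartialFunction[Any, Unit] =
    PartialFunction.empty
}

case class Ping(id: String)
case object Expire

// Shape of a migrated endpoint such as HeartbeatReceiver.
class SketchEndpoint extends SketchThreadSafeRpcEndpoint {
  private var lastSeen = Map.empty[String, Long]

  override def receive: PartialFunction[Any, Unit] = {
    case Expire => lastSeen = Map.empty
  }

  override def receiveAndReply(context: SketchRpcCallContext): PartialFunction[Any, Unit] = {
    case Ping(id) =>
      lastSeen += id -> System.currentTimeMillis()
      context.reply("pong")   // replaces `sender ! response`
  }
}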
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala | 66
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 79
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala (renamed from core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala) | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 39
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 148
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala | 93
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala | 48
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala | 72
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala (renamed from core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala) | 119
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala (renamed from core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala) | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 10
-rw-r--r--  core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala | 81
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala | 14
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala | 28
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 37
27 files changed, 566 insertions, 479 deletions
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 9f8ad03b91..5871b8c869 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -17,15 +17,15 @@
package org.apache.spark
-import scala.concurrent.duration._
-import scala.collection.mutable
+import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors}
-import akka.actor.{Actor, Cancellable}
+import scala.collection.mutable
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
-import org.apache.spark.util.ActorLogReceive
+import org.apache.spark.util.Utils
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
@@ -51,9 +51,11 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
* Lives in the driver to receive heartbeats from executors.
*/
private[spark] class HeartbeatReceiver(sc: SparkContext)
- extends Actor with ActorLogReceive with Logging {
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ override val rpcEnv: RpcEnv = sc.env.rpcEnv
- private var scheduler: TaskScheduler = null
+ private[spark] var scheduler: TaskScheduler = null
// executor ID -> timestamp of when the last heartbeat from this executor was received
private val executorLastSeen = new mutable.HashMap[String, Long]
@@ -69,34 +71,44 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000).
getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000))
- private var timeoutCheckingTask: Cancellable = null
-
- override def preStart(): Unit = {
- import context.dispatcher
- timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
- checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts)
- super.preStart()
+ private var timeoutCheckingTask: ScheduledFuture[_] = null
+
+ private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor(
+ Utils.namedThreadFactory("heartbeat-timeout-checking-thread"))
+
+ private val killExecutorThread = Executors.newSingleThreadExecutor(
+ Utils.namedThreadFactory("kill-executor-thread"))
+
+ override def onStart(): Unit = {
+ timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ Option(self).foreach(_.send(ExpireDeadHosts))
+ }
+ }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
}
-
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case ExpireDeadHosts =>
+ expireDeadHosts()
case TaskSchedulerIsSet =>
scheduler = sc.taskScheduler
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
if (scheduler != null) {
val unknownExecutor = !scheduler.executorHeartbeatReceived(
executorId, taskMetrics, blockManagerId)
val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
executorLastSeen(executorId) = System.currentTimeMillis()
- sender ! response
+ context.reply(response)
} else {
// Because Executor will sleep several seconds before sending the first "Heartbeat", this
// case rarely happens. However, if it really happens, log it and ask the executor to
// register itself again.
logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet")
- sender ! HeartbeatResponse(reregisterBlockManager = true)
+ context.reply(HeartbeatResponse(reregisterBlockManager = true))
}
- case ExpireDeadHosts =>
- expireDeadHosts()
}
private def expireDeadHosts(): Unit = {
@@ -109,17 +121,25 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " +
s"timed out after ${now - lastSeenMs} ms"))
if (sc.supportDynamicAllocation) {
- sc.killExecutor(executorId)
+ // Asynchronously kill the executor to avoid blocking the current thread
+ killExecutorThread.submit(new Runnable {
+ override def run(): Unit = sc.killExecutor(executorId)
+ })
}
executorLastSeen.remove(executorId)
}
}
}
- override def postStop(): Unit = {
+ override def onStop(): Unit = {
if (timeoutCheckingTask != null) {
- timeoutCheckingTask.cancel()
+ timeoutCheckingTask.cancel(true)
}
- super.postStop()
+ timeoutCheckingThread.shutdownNow()
+ killExecutorThread.shutdownNow()
}
}
+
+object HeartbeatReceiver {
+ val ENDPOINT_NAME = "HeartbeatReceiver"
+}
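The hunk above also swaps Akka's scheduler for a plain JDK `ScheduledExecutorService`. Below is a self-contained sketch of that scheduling pattern, assuming only the JDK; `postToSelf` is a hypothetical stand-in for `self.send(ExpireDeadHosts)`, and the bare `printStackTrace` stands in for `Utils.tryLogNonFatalError`.

import java.util.concurrent.{Executors, ScheduledFuture, ThreadFactory, TimeUnit}

class PeriodicSelfMessage(postToSelf: () => Unit, intervalMs: Long) {
  private val factory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "heartbeat-timeout-checking-thread")
      t.setDaemon(true)  // do not block JVM exit
      t
    }
  }
  private val scheduler = Executors.newSingleThreadScheduledExecutor(factory)
  private var task: ScheduledFuture[_] = null

  // onStart: post the expiry message back to the endpoint at a fixed rate.
  def start(): Unit = {
    task = scheduler.scheduleAtFixedRate(new Runnable {
      override def run(): Unit =
        try postToSelf()
        catch { case e: Throwable => e.printStackTrace() }
    }, 0, intervalMs, TimeUnit.MILLISECONDS)
  }

  // onStop: cancel the running task and shut the thread down.
  def stop(): Unit = {
    if (task != null) task.cancel(true)
    scheduler.shutdownNow()
  }
}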
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3b73a8a8fd..942c5975ec 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -32,8 +32,6 @@ import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
-import akka.actor.Props
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
@@ -48,12 +46,13 @@ import org.apache.mesos.MesosNativeLibrary
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.executor.TriggerThreadDump
+import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump}
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat,
FixedLengthBinaryInputFormat}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
+import org.apache.spark.rpc.RpcAddress
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
SparkDeploySchedulerBackend, SimrSchedulerBackend}
@@ -360,14 +359,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will
// retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640)
- private val heartbeatReceiver = env.actorSystem.actorOf(
- Props(new HeartbeatReceiver(this)), "HeartbeatReceiver")
+ private val heartbeatReceiver = env.rpcEnv.setupEndpoint(
+ HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this))
// Create and start the scheduler
private[spark] var (schedulerBackend, taskScheduler) =
SparkContext.createTaskScheduler(this, master)
- heartbeatReceiver ! TaskSchedulerIsSet
+ heartbeatReceiver.send(TaskSchedulerIsSet)
@volatile private[spark] var dagScheduler: DAGScheduler = _
try {
@@ -455,10 +454,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
if (executorId == SparkContext.DRIVER_IDENTIFIER) {
Some(Utils.getThreadDump())
} else {
- val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
- val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
- Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
- AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get
+ val endpointRef = env.rpcEnv.setupEndpointRef(
+ SparkEnv.executorActorSystemName,
+ RpcAddress(host, port),
+ ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME)
+ Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump))
}
} catch {
case e: Exception =>
@@ -1418,7 +1419,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
dagScheduler = null
listenerBus.stop()
eventLogger.foreach(_.stop())
- env.actorSystem.stop(heartbeatReceiver)
+ env.rpcEnv.stop(heartbeatReceiver)
progressBar.foreach(_.stop())
taskScheduler = null
// TODO: Cache.stop()?
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 4a2ed82a40..55be0a59fe 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -295,7 +295,9 @@ object SparkEnv extends Logging {
}
}
- def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = {
+ def registerOrLookupEndpoint(
+ name: String, endpointCreator: => RpcEndpoint):
+ RpcEndpointRef = {
if (isDriver) {
logInfo("Registering " + name)
rpcEnv.setupEndpoint(name, endpointCreator)
@@ -334,12 +336,13 @@ object SparkEnv extends Logging {
new NioBlockTransferService(conf, securityManager)
}
- val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
- "BlockManagerMaster",
- new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
+ BlockManagerMaster.DRIVER_ENDPOINT_NAME,
+ new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)),
+ conf, isDriver)
// NB: blockManager is not valid until initialize() is called later.
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
+ val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager,
numUsableCores)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 900e678ee0..8300f9f219 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -21,39 +21,45 @@ import java.net.URL
import java.nio.ByteBuffer
import scala.collection.mutable
-import scala.concurrent.Await
+import scala.util.{Failure, Success}
-import akka.actor.{Actor, ActorSelection, Props}
-import akka.pattern.Patterns
-import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
-
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
+import org.apache.spark.rpc._
+import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.util.{SignalLogger, Utils}
private[spark] class CoarseGrainedExecutorBackend(
+ override val rpcEnv: RpcEnv,
driverUrl: String,
executorId: String,
hostPort: String,
cores: Int,
userClassPath: Seq[URL],
env: SparkEnv)
- extends Actor with ActorLogReceive with ExecutorBackend with Logging {
+ extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
var executor: Executor = null
- var driver: ActorSelection = null
+ @volatile var driver: Option[RpcEndpointRef] = None
- override def preStart() {
+ override def onStart() {
+ import scala.concurrent.ExecutionContext.Implicits.global
logInfo("Connecting to driver: " + driverUrl)
- driver = context.actorSelection(driverUrl)
- driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls)
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
+ driver = Some(ref)
+ ref.sendWithReply[RegisteredExecutor.type](
+ RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
+ } onComplete {
+ case Success(msg) => Utils.tryLogNonFatalError {
+ Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor
+ }
+ case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e)
+ }
}
def extractLogUrls: Map[String, String] = {
@@ -62,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend(
.map(e => (e._1.substring(prefix.length).toLowerCase, e._2))
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receive: PartialFunction[Any, Unit] = {
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
val (hostname, _) = Utils.parseHostPort(hostPort)
@@ -92,23 +98,28 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.killTask(taskId, interruptThread)
}
- case x: DisassociatedEvent =>
- if (x.remoteAddress == driver.anchorPath.address) {
- logError(s"Driver $x disassociated! Shutting down.")
- System.exit(1)
- } else {
- logWarning(s"Received irrelevant DisassociatedEvent $x")
- }
-
case StopExecutor =>
logInfo("Driver commanded a shutdown")
executor.stop()
- context.stop(self)
- context.system.shutdown()
+ stop()
+ rpcEnv.shutdown()
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (driver.exists(_.address == remoteAddress)) {
+ logError(s"Driver $remoteAddress disassociated! Shutting down.")
+ System.exit(1)
+ } else {
+ logWarning(s"An unknown ($remoteAddress) driver disconnected.")
+ }
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
- driver ! StatusUpdate(executorId, taskId, state, data)
+ val msg = StatusUpdate(executorId, taskId, state, data)
+ driver match {
+ case Some(driverRef) => driverRef.send(msg)
+ case None => logWarning(s"Drop $msg because we have not yet connected to the driver")
+ }
}
}
@@ -132,16 +143,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Bootstrap to fetch the driver's Spark properties.
val executorConf = new SparkConf
val port = executorConf.getInt("spark.executor.port", 0)
- val (fetcher, _) = AkkaUtils.createActorSystem(
+ val fetcher = RpcEnv.create(
"driverPropsFetcher",
hostname,
port,
executorConf,
new SecurityManager(executorConf))
- val driver = fetcher.actorSelection(driverUrl)
- val timeout = AkkaUtils.askTimeout(executorConf)
- val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
- val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++
+ val driver = fetcher.setupEndpointRefByURI(driverUrl)
+ val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++
Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
@@ -162,16 +171,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val boundPort = env.conf.getInt("spark.executor.port", 0)
assert(boundPort != 0)
- // Start the CoarseGrainedExecutorBackend actor.
+ // Start the CoarseGrainedExecutorBackend endpoint.
val sparkHostPort = hostname + ":" + boundPort
- env.actorSystem.actorOf(
- Props(classOf[CoarseGrainedExecutorBackend],
- driverUrl, executorId, sparkHostPort, cores, userClassPath, env),
- name = "Executor")
+ env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
+ env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
workerUrl.foreach { url =>
env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
}
- env.actorSystem.awaitTermination()
+ env.rpcEnv.awaitTermination()
}
}
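The new `onStart` above registers with the driver asynchronously: resolve the driver reference, ask it to register, then deliver the acknowledgement back to the endpoint's own mailbox. A sketch of that future chain follows; `lookupDriver`, `register`, and `deliverToSelf` are hypothetical stand-ins for `asyncSetupEndpointRefByURI`, `sendWithReply`, and `self.send`.

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object RegistrationSketch {
  implicit val ec: ExecutionContext = ExecutionContext.global

  def connect(
      lookupDriver: () => Future[String],
      register: String => Future[Any],
      deliverToSelf: Any => Unit): Unit = {
    // Chain the lookup and the registration ask, then route the result
    // back through the endpoint's own message loop.
    lookupDriver().flatMap(register).onComplete {
      case Success(ack) => deliverToSelf(ack)   // e.g. RegisteredExecutor
      case Failure(e)   => Console.err.println(s"Cannot register with driver: $e")
    }
  }
}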
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index bf3135ef08..14f99a464b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -27,8 +27,6 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
-import akka.actor.Props
-
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
@@ -88,9 +86,9 @@ private[spark] class Executor(
env.blockManager.initialize(conf.getAppId)
}
- // Create an actor for receiving RPCs from the driver
- private val executorActor = env.actorSystem.actorOf(
- Props(new ExecutorActor(executorId)), "ExecutorActor")
+ // Create an RpcEndpoint for receiving RPCs from the driver
+ private val executorEndpoint = env.rpcEnv.setupEndpoint(
+ ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId))
// Whether to load classes in user jars before those in Spark jars
private val userClassPathFirst: Boolean = {
@@ -139,7 +137,7 @@ private[spark] class Executor(
def stop(): Unit = {
env.metricsSystem.report()
- env.actorSystem.stop(executorActor)
+ env.rpcEnv.stop(executorEndpoint)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
@@ -391,11 +389,8 @@ private[spark] class Executor(
}
}
- private val timeout = AkkaUtils.lookupTimeout(conf)
- private val retryAttempts = AkkaUtils.numRetries(conf)
- private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
private val heartbeatReceiverRef =
- AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+ RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv)
/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
@@ -426,8 +421,7 @@ private[spark] class Executor(
val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
try {
- val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
- retryAttempts, retryIntervalMs, timeout)
+ val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message)
if (response.reregisterBlockManager) {
logWarning("Told to re-register on heartbeat")
env.blockManager.reregister()
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala
index 3e47d13f75..cf362f8464 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala
@@ -17,10 +17,8 @@
package org.apache.spark.executor
-import akka.actor.Actor
-import org.apache.spark.Logging
-
-import org.apache.spark.util.{Utils, ActorLogReceive}
+import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
+import org.apache.spark.util.Utils
/**
* Driver -> Executor message to trigger a thread dump.
@@ -28,14 +26,18 @@ import org.apache.spark.util.{Utils, ActorLogReceive}
private[spark] case object TriggerThreadDump
/**
- * Actor that runs inside of executors to enable driver -> executor RPC.
+ * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC.
*/
private[spark]
-class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
+class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint {
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case TriggerThreadDump =>
- sender ! Utils.getThreadDump()
+ context.reply(Utils.getThreadDump())
}
}
+
+object ExecutorEndpoint {
+ val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint"
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 7985941d94..d47e41abcf 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -40,10 +40,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
/**
* Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement
- * [[RpcEndpoint.self]].
- *
- * Note: This method won't return null. `IllegalArgumentException` will be thrown if calling this
- * on a non-existent endpoint.
+ * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist.
*/
private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef
@@ -59,20 +56,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
/**
- * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should
- * make sure thread-safely sending messages to [[RpcEndpoint]].
- *
- * Thread-safety means processing of one message happens before processing of the next message by
- * the same [[RpcEndpoint]]. In the other words, changes to internal fields of a [[RpcEndpoint]]
- * are visible when processing the next message, and fields in the [[RpcEndpoint]] need not be
- * volatile or equivalent.
- *
- * However, there is no guarantee that the same thread will be executing the same [[RpcEndpoint]]
- * for different messages.
- */
- def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
-
- /**
* Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously.
*/
def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef]
@@ -181,7 +164,7 @@ private[spark] trait RpcEnvFactory {
* constructor onStart receive* onStop
*
* Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use
- * [[RpcEnv.setupThreadSafeEndpoint]]
+ * [[ThreadSafeRpcEndpoint]]
*
* If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be
* invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it.
@@ -195,7 +178,7 @@ private[spark] trait RpcEndpoint {
/**
* The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is
- * called.
+ * called. And `self` will become `null` when `onStop` is called.
*
* Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not
* valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called.
@@ -279,6 +262,19 @@ private[spark] trait RpcEndpoint {
}
/**
+ * A trait that requires the RpcEnv to send messages to it in a thread-safe manner.
+ *
+ * Thread-safety means processing of one message happens before processing of the next message by
+ * the same [[ThreadSafeRpcEndpoint]]. In other words, changes to internal fields of a
+ * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
+ * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
+ *
+ * However, there is no guarantee that the same thread will be executing the same
+ * [[ThreadSafeRpcEndpoint]] for different messages.
+ */
+trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+/**
* A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
*/
private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
@@ -407,7 +403,8 @@ private[spark] object RpcAddress {
}
/**
- * A callback that [[RpcEndpoint]] can use it to send back a message or failure.
+ * A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe
+ * and can be called in any thread.
*/
private[spark] trait RpcCallContext {
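The updated docs above state that `RpcCallContext` is thread-safe and may be called from any thread. Here is a sketch, against a stand-in context trait rather than the private Spark one, of the asynchronous reply style this enables, the same shape the YarnSchedulerEndpoint later in this diff uses when forwarding requests to the ApplicationMaster.

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

trait ReplyContext {                 // stand-in for the private RpcCallContext
  def reply(response: Any): Unit
  def sendFailure(e: Throwable): Unit
}

object AsyncReplySketch {
  implicit val ec: ExecutionContext = ExecutionContext.global

  // The message handler returns immediately; a pool thread replies later.
  def handle(context: ReplyContext)(slowCall: () => Boolean): Unit = {
    Future(slowCall()).onComplete {
      case Success(ok) => context.reply(ok)
      case Failure(e)  => context.sendFailure(e)
    }
  }
}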
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 769d59b7b3..9e06147dff 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -82,17 +82,9 @@ private[spark] class AkkaRpcEnv private[akka] (
/**
* Retrieve the [[RpcEndpointRef]] of `endpoint`.
*/
- override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
- val endpointRef = endpointToRef.get(endpoint)
- require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}")
- endpointRef
- }
+ override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint)
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
- setupThreadSafeEndpoint(name, endpoint)
- }
-
- override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
@volatile var endpointRef: AkkaRpcEndpointRef = null
// Use lazy because the Actor needs to use `endpointRef`.
// So `actorRef` should be created after assigning `endpointRef`.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7227fa9da4..917cce1f96 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -23,14 +23,10 @@ import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
-import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.control.NonFatal
-import akka.pattern.ask
-import akka.util.Timeout
-
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
@@ -165,11 +161,8 @@ class DAGScheduler(
taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics)
blockManagerId: BlockManagerId): Boolean = {
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
- implicit val timeout = Timeout(600 seconds)
-
- Await.result(
- blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId),
- timeout.duration).asInstanceOf[Boolean]
+ blockManagerMaster.driverEndpoint.askWithReply[Boolean](
+ BlockManagerHeartbeat(blockManagerId), 600 seconds)
}
// Called by TaskScheduler when an executor fails.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 9bf74f4be1..70364cea62 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster
import java.nio.ByteBuffer
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.{SerializableBuffer, Utils}
private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
@@ -41,6 +42,7 @@ private[spark] object CoarseGrainedClusterMessages {
// Executors to driver
case class RegisterExecutor(
executorId: String,
+ executorRef: RpcEndpointRef,
hostPort: String,
cores: Int,
logUrls: Map[String, String])
@@ -70,6 +72,8 @@ private[spark] object CoarseGrainedClusterMessages {
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
+ case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage
+
// Exchanged between the driver and the AM in Yarn client mode
case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String)
extends CoarseGrainedClusterMessage
@@ -77,7 +81,7 @@ private[spark] object CoarseGrainedClusterMessages {
// Messages exchanged between the driver and the cluster manager for executor allocation
// In Yarn mode, these are exchanged between the driver and the AM
- case object RegisterClusterManager extends CoarseGrainedClusterMessage
+ case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage
// Request executors by specifying the new total number of executors desired
// This includes executors already pending or running
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 5d258d9da4..4c49da87af 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -17,20 +17,16 @@
package org.apache.spark.scheduler.cluster
+import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.concurrent.Await
-import scala.concurrent.duration._
-
-import akka.actor._
-import akka.pattern.ask
-import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+import org.apache.spark.rpc._
import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
+import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils}
/**
* A scheduler backend that waits for coarse grained executors to connect to it through Akka.
@@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut
* (spark.deploy.*).
*/
private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)
extends ExecutorAllocationClient with SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
@@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
// Total number of executors that are currently registered
var totalRegisteredExecutors = new AtomicInteger(0)
val conf = scheduler.sc.conf
- private val timeout = AkkaUtils.askTimeout(conf)
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
// Submit tasks only after (registered resources / total expected resources)
// is equal to at least this value, that is double between 0 and 1.
@@ -71,48 +66,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
// Executors we have requested the cluster manager to kill that have not died yet
private val executorsPendingToRemove = new HashSet[String]
- class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive {
+ class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
+ extends ThreadSafeRpcEndpoint with Logging {
override protected def log = CoarseGrainedSchedulerBackend.this.log
- private val addressToExecutorId = new HashMap[Address, String]
- override def preStart() {
- // Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ private val addressToExecutorId = new HashMap[RpcAddress, String]
+
+ private val reviveThread =
+ Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread"))
+ override def onStart() {
// Periodically revive offers to allow delay scheduling to work
val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000)
- import context.dispatcher
- context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers)
- }
-
- def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RegisterExecutor(executorId, hostPort, cores, logUrls) =>
- Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
- if (executorDataMap.contains(executorId)) {
- sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
- } else {
- logInfo("Registered executor: " + sender + " with ID " + executorId)
- sender ! RegisteredExecutor
-
- addressToExecutorId(sender.path.address) = executorId
- totalCoreCount.addAndGet(cores)
- totalRegisteredExecutors.addAndGet(1)
- val (host, _) = Utils.parseHostPort(hostPort)
- val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls)
- // This must be synchronized because variables mutated
- // in this block are read when requesting executors
- CoarseGrainedSchedulerBackend.this.synchronized {
- executorDataMap.put(executorId, data)
- if (numPendingExecutors > 0) {
- numPendingExecutors -= 1
- logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
- }
- }
- listenerBus.post(
- SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
- makeOffers()
+ reviveThread.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ Option(self).foreach(_.send(ReviveOffers))
}
+ }, 0, reviveInterval, TimeUnit.MILLISECONDS)
+ }
+ override def receive: PartialFunction[Any, Unit] = {
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
@@ -133,33 +106,58 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
case KillTask(taskId, executorId, interruptThread) =>
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
- executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread)
+ executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread))
case None =>
// Ignoring the task kill since the executor is not registered.
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
}
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) =>
+ Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
+ if (executorDataMap.contains(executorId)) {
+ context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
+ } else {
+ logInfo("Registered executor: " + executorRef + " with ID " + executorId)
+ context.reply(RegisteredExecutor)
+
+ addressToExecutorId(executorRef.address) = executorId
+ totalCoreCount.addAndGet(cores)
+ totalRegisteredExecutors.addAndGet(1)
+ val (host, _) = Utils.parseHostPort(hostPort)
+ val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls)
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ CoarseGrainedSchedulerBackend.this.synchronized {
+ executorDataMap.put(executorId, data)
+ if (numPendingExecutors > 0) {
+ numPendingExecutors -= 1
+ logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
+ }
+ }
+ listenerBus.post(
+ SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
+ makeOffers()
+ }
case StopDriver =>
- sender ! true
- context.stop(self)
+ context.reply(true)
+ stop()
case StopExecutors =>
logInfo("Asking each executor to shut down")
for ((_, executorData) <- executorDataMap) {
- executorData.executorActor ! StopExecutor
+ executorData.executorEndpoint.send(StopExecutor)
}
- sender ! true
+ context.reply(true)
case RemoveExecutor(executorId, reason) =>
removeExecutor(executorId, reason)
- sender ! true
-
- case DisassociatedEvent(_, address, _) =>
- addressToExecutorId.get(address).foreach(removeExecutor(_,
- "remote Akka client disassociated"))
+ context.reply(true)
case RetrieveSparkProps =>
- sender ! sparkProperties
+ context.reply(sparkProperties)
}
// Make fake resource offers on all executors
@@ -169,6 +167,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
}.toSeq))
}
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_,
+ "remote Rpc client disassociated"))
+ }
+
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
val executorData = executorDataMap(executorId)
@@ -199,7 +202,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
else {
val executorData = executorDataMap(task.executorId)
executorData.freeCores -= scheduler.CPUS_PER_TASK
- executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask))
+ executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
}
}
}
@@ -223,9 +226,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
case None => logError(s"Asked to remove non-existent executor $executorId")
}
}
+
+ override def onStop() {
+ reviveThread.shutdownNow()
+ }
}
- var driverActor: ActorRef = null
+ var driverEndpoint: RpcEndpointRef = null
val taskIdsOnSlave = new HashMap[String, HashSet[String]]
override def start() {
@@ -236,16 +243,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
}
}
// TODO (prashant) send conf instead of properties
- driverActor = actorSystem.actorOf(
- Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ driverEndpoint = rpcEnv.setupEndpoint(
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties))
}
def stopExecutors() {
try {
- if (driverActor != null) {
+ if (driverEndpoint != null) {
logInfo("Shutting down all executors")
- val future = driverActor.ask(StopExecutors)(timeout)
- Await.ready(future, timeout)
+ driverEndpoint.askWithReply[Boolean](StopExecutors)
}
} catch {
case e: Exception =>
@@ -256,22 +262,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
override def stop() {
stopExecutors()
try {
- if (driverActor != null) {
- val future = driverActor.ask(StopDriver)(timeout)
- Await.ready(future, timeout)
+ if (driverEndpoint != null) {
+ driverEndpoint.askWithReply[Boolean](StopDriver)
}
} catch {
case e: Exception =>
- throw new SparkException("Error stopping standalone scheduler's driver actor", e)
+ throw new SparkException("Error stopping standalone scheduler's driver endpoint", e)
}
}
override def reviveOffers() {
- driverActor ! ReviveOffers
+ driverEndpoint.send(ReviveOffers)
}
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
- driverActor ! KillTask(taskId, executorId, interruptThread)
+ driverEndpoint.send(KillTask(taskId, executorId, interruptThread))
}
override def defaultParallelism(): Int = {
@@ -281,11 +286,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
// Called by subclasses when notified of a lost worker
def removeExecutor(executorId: String, reason: String) {
try {
- val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
- Await.ready(future, timeout)
+ driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason))
} catch {
case e: Exception =>
- throw new SparkException("Error notifying standalone scheduler's driver actor", e)
+ throw new SparkException("Error notifying standalone scheduler's driver endpoint", e)
}
}
@@ -391,5 +395,5 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
}
private[spark] object CoarseGrainedSchedulerBackend {
- val ACTOR_NAME = "CoarseGrainedScheduler"
+ val ENDPOINT_NAME = "CoarseGrainedScheduler"
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
index 5e571efe76..26e72c0bff 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
@@ -17,20 +17,20 @@
package org.apache.spark.scheduler.cluster
-import akka.actor.{Address, ActorRef}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress}
/**
* Grouping of data for an executor used by CoarseGrainedSchedulerBackend.
*
- * @param executorActor The ActorRef representing this executor
+ * @param executorEndpoint The RpcEndpointRef representing this executor
* @param executorAddress The network address of this executor
* @param executorHost The hostname that this executor is running on
* @param freeCores The current number of cores available for work on the executor
* @param totalCores The total number of cores available to the executor
*/
private[cluster] class ExecutorData(
- val executorActor: ActorRef,
- val executorAddress: Address,
+ val executorEndpoint: RpcEndpointRef,
+ val executorAddress: RpcAddress,
override val executorHost: String,
var freeCores: Int,
override val totalCores: Int,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index 06786a5952..0324c9dab9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -19,16 +19,16 @@ package org.apache.spark.scheduler.cluster
import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.rpc.RpcAddress
import org.apache.spark.{Logging, SparkContext, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.TaskSchedulerImpl
-import org.apache.spark.util.AkkaUtils
private[spark] class SimrSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
driverFilePath: String)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
with Logging {
val tmpPath = new Path(driverFilePath + "_tmp")
@@ -39,12 +39,9 @@ private[spark] class SimrSchedulerBackend(
override def start() {
super.start()
- val driverUrl = AkkaUtils.address(
- AkkaUtils.protocol(actorSystem),
- SparkEnv.driverActorSystemName,
- sc.conf.get("spark.driver.host"),
- sc.conf.get("spark.driver.port"),
- CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName,
+ RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt),
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
val conf = SparkHadoopUtil.get.newConfiguration(sc.conf)
val fs = FileSystem.get(conf)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index ffd4825705..7eb3fdc19b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -19,17 +19,18 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.Semaphore
+import org.apache.spark.rpc.RpcAddress
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
masters: Array[String])
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
with AppClientListener
with Logging {
@@ -48,12 +49,9 @@ private[spark] class SparkDeploySchedulerBackend(
super.start()
// The endpoint for executors to talk to us
- val driverUrl = AkkaUtils.address(
- AkkaUtils.protocol(actorSystem),
- SparkEnv.driverActorSystemName,
- conf.get("spark.driver.host"),
- conf.get("spark.driver.port"),
- CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName,
+ RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt),
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
val args = Seq(
"--driver-url", driverUrl,
"--executor-id", "{{EXECUTOR_ID}}",
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 5a38ad9f2b..f72566c370 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -19,10 +19,8 @@ package org.apache.spark.scheduler.cluster
import scala.concurrent.{Future, ExecutionContext}
-import akka.actor.{Actor, ActorRef, Props}
-import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.rpc._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.ui.JettyUtils
@@ -37,7 +35,7 @@ import scala.util.control.NonFatal
private[spark] abstract class YarnSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) {
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
minRegisteredRatio = 0.8
@@ -45,10 +43,8 @@ private[spark] abstract class YarnSchedulerBackend(
protected var totalExpectedExecutors = 0
- private val yarnSchedulerActor: ActorRef =
- actorSystem.actorOf(
- Props(new YarnSchedulerActor),
- name = YarnSchedulerBackend.ACTOR_NAME)
+ private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint(
+ YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv))
private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf)
@@ -57,16 +53,14 @@ private[spark] abstract class YarnSchedulerBackend(
* This includes executors already pending or running.
*/
override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- AkkaUtils.askWithReply[Boolean](
- RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout)
+ yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal))
}
/**
* Request that the ApplicationMaster kill the specified executors.
*/
override def doKillExecutors(executorIds: Seq[String]): Boolean = {
- AkkaUtils.askWithReply[Boolean](
- KillExecutors(executorIds), yarnSchedulerActor, askTimeout)
+ yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds))
}
override def sufficientResourcesRegistered(): Boolean = {
@@ -96,64 +90,71 @@ private[spark] abstract class YarnSchedulerBackend(
}
/**
- * An actor that communicates with the ApplicationMaster.
+ * An [[RpcEndpoint]] that communicates with the ApplicationMaster.
*/
- private class YarnSchedulerActor extends Actor {
- private var amActor: Option[ActorRef] = None
-
- implicit val askAmActorExecutor = ExecutionContext.fromExecutor(
- Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor"))
+ private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv)
+ extends ThreadSafeRpcEndpoint with Logging {
+ private var amEndpoint: Option[RpcEndpointRef] = None
- override def preStart(): Unit = {
- // Listen for disassociation events
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- }
+ private val askAmThreadPool =
+ Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool")
+ implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool)
override def receive: PartialFunction[Any, Unit] = {
- case RegisterClusterManager =>
- logInfo(s"ApplicationMaster registered as $sender")
- amActor = Some(sender)
+ case RegisterClusterManager(am) =>
+ logInfo(s"ApplicationMaster registered as $am")
+ amEndpoint = Some(am)
+
+ case AddWebUIFilter(filterName, filterParams, proxyBase) =>
+ addWebUIFilter(filterName, filterParams, proxyBase)
+
+ }
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case r: RequestExecutors =>
- amActor match {
- case Some(actor) =>
- val driverActor = sender
+ amEndpoint match {
+ case Some(am) =>
Future {
- driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ context.reply(am.askWithReply[Boolean](r))
} onFailure {
- case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e)
+ case NonFatal(e) =>
+ logError(s"Sending $r to AM was unsuccessful", e)
+ context.sendFailure(e)
}
case None =>
logWarning("Attempted to request executors before the AM has registered!")
- sender ! false
+ context.reply(false)
}
case k: KillExecutors =>
- amActor match {
- case Some(actor) =>
- val driverActor = sender
+ amEndpoint match {
+ case Some(am) =>
Future {
- driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ context.reply(am.askWithReply[Boolean](k))
} onFailure {
- case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e)
+ case NonFatal(e) =>
+ logError(s"Sending $k to AM was unsuccessful", e)
+ context.sendFailure(e)
}
case None =>
logWarning("Attempted to kill executors before the AM has registered!")
- sender ! false
+ context.reply(false)
}
- case AddWebUIFilter(filterName, filterParams, proxyBase) =>
- addWebUIFilter(filterName, filterParams, proxyBase)
- sender ! true
+ }
- case d: DisassociatedEvent =>
- if (amActor.isDefined && sender == amActor.get) {
- logWarning(s"ApplicationMaster has disassociated: $d")
- }
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (amEndpoint.exists(_.address == remoteAddress)) {
+ logWarning(s"ApplicationMaster has disassociated: $remoteAddress")
+ }
+ }
+
+ override def onStop(): Unit = {
+ askAmThreadPool.shutdownNow()
}
}
}
private[spark] object YarnSchedulerBackend {
- val ACTOR_NAME = "YarnScheduler"
+ val ENDPOINT_NAME = "YarnScheduler"
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index e13de0f46e..b037a4966c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
master: String)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
with MScheduler
with Logging {
@@ -148,7 +148,7 @@ private[spark] class CoarseMesosSchedulerBackend(
SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
- CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
val uri = conf.get("spark.executor.uri", null)
if (uri == null) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index eb3f999b5b..70a477a689 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -18,17 +18,14 @@
package org.apache.spark.scheduler.local
import java.nio.ByteBuffer
+import java.util.concurrent.{Executors, TimeUnit}
-import scala.concurrent.duration._
-import scala.language.postfixOps
-
-import akka.actor.{Actor, ActorRef, Props}
-
+import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
-import org.apache.spark.util.ActorLogReceive
private case class ReviveOffers()
@@ -39,17 +36,19 @@ private case class KillTask(taskId: Long, interruptThread: Boolean)
private case class StopExecutor()
/**
- * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
- * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
+ * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the
+ * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
* and the TaskSchedulerImpl.
*/
-private[spark] class LocalActor(
+private[spark] class LocalEndpoint(
+ override val rpcEnv: RpcEnv,
scheduler: TaskSchedulerImpl,
executorBackend: LocalBackend,
private val totalCores: Int)
- extends Actor with ActorLogReceive with Logging {
+ extends ThreadSafeRpcEndpoint with Logging {
- import context.dispatcher // to use Akka's scheduler.scheduleOnce()
+ private val reviveThread = Executors.newSingleThreadScheduledExecutor(
+ Utils.namedThreadFactory("local-revive-thread"))
private var freeCores = totalCores
@@ -59,7 +58,7 @@ private[spark] class LocalActor(
private val executor = new Executor(
localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true)
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receive: PartialFunction[Any, Unit] = {
case ReviveOffers =>
reviveOffers()
@@ -87,9 +86,17 @@ private[spark] class LocalActor(
}
if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) {
// Try reviveOffers() again after 1 second, because the scheduler may be waiting for a locality timeout
- context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers)
+ reviveThread.schedule(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ Option(self).foreach(_.send(ReviveOffers))
+ }
+ }, 1000, TimeUnit.MILLISECONDS)
}
}
+
+ override def onStop(): Unit = {
+ reviveThread.shutdownNow()
+ }
}
/**
@@ -101,31 +108,30 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:
extends SchedulerBackend with ExecutorBackend {
private val appId = "local-" + System.currentTimeMillis
- var localActor: ActorRef = null
+ var localEndpoint: RpcEndpointRef = null
override def start() {
- localActor = SparkEnv.get.actorSystem.actorOf(
- Props(new LocalActor(scheduler, this, totalCores)),
- "LocalBackendActor")
+ localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint(
+ "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores))
}
override def stop() {
- localActor ! StopExecutor
+ localEndpoint.send(StopExecutor)
}
override def reviveOffers() {
- localActor ! ReviveOffers
+ localEndpoint.send(ReviveOffers)
}
override def defaultParallelism(): Int =
scheduler.conf.getInt("spark.default.parallelism", totalCores)
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
- localActor ! KillTask(taskId, interruptThread)
+ localEndpoint.send(KillTask(taskId, interruptThread))
}
override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
- localActor ! StatusUpdate(taskId, state, serializedData)
+ localEndpoint.send(StatusUpdate(taskId, state, serializedData))
}
override def applicationId(): String = appId
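The scheduleOnce call that LocalActor borrowed from Akka's scheduler is replaced above by a plain single-threaded ScheduledExecutorService that re-sends ReviveOffers to the endpoint itself. A standalone sketch of that delayed self-send pattern, independent of Spark and with made-up message and object names:

import java.util.concurrent.{Executors, TimeUnit}
import scala.util.control.NonFatal

object DelayedSelfSend {
  private val reviveThread = Executors.newSingleThreadScheduledExecutor()

  // Hypothetical stand-in for localEndpoint.send(ReviveOffers).
  private def sendReviveOffers(): Unit = println("ReviveOffers")

  def scheduleRevive(): Unit = {
    // Re-deliver the message to ourselves after one second; non-fatal errors
    // are caught so the single scheduler thread is never killed.
    reviveThread.schedule(new Runnable {
      override def run(): Unit = {
        try sendReviveOffers()
        catch { case NonFatal(t) => t.printStackTrace() }
      }
    }, 1000, TimeUnit.MILLISECONDS)
  }

  def main(args: Array[String]): Unit = {
    scheduleRevive()
    Thread.sleep(1500)
    reviveThread.shutdownNow()
  }
}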
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index fc31296f4d..1aa0ef18de 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Random
-import akka.actor.{ActorSystem, Props}
import sun.nio.ch.DirectBuffer
import org.apache.spark._
@@ -37,6 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.hash.HashShuffleManager
@@ -64,7 +64,7 @@ private[spark] class BlockResult(
*/
private[spark] class BlockManager(
executorId: String,
- actorSystem: ActorSystem,
+ rpcEnv: RpcEnv,
val master: BlockManagerMaster,
defaultSerializer: Serializer,
maxMemory: Long,
@@ -136,9 +136,9 @@ private[spark] class BlockManager(
// Whether to compress shuffle output temporarily spilled to disk
private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
- private val slaveActor = actorSystem.actorOf(
- Props(new BlockManagerSlaveActor(this, mapOutputTracker)),
- name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
+ private val slaveEndpoint = rpcEnv.setupEndpoint(
+ "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next,
+ new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker))
// Pending re-registration action being executed asynchronously or null if none is pending.
// Accesses should synchronize on asyncReregisterLock.
@@ -167,7 +167,7 @@ private[spark] class BlockManager(
*/
def this(
execId: String,
- actorSystem: ActorSystem,
+ rpcEnv: RpcEnv,
master: BlockManagerMaster,
serializer: Serializer,
conf: SparkConf,
@@ -176,7 +176,7 @@ private[spark] class BlockManager(
blockTransferService: BlockTransferService,
securityManager: SecurityManager,
numUsableCores: Int) = {
- this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
+ this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf),
conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores)
}
@@ -186,7 +186,7 @@ private[spark] class BlockManager(
* where it is only learned after registration with the TaskScheduler).
*
* This method initializes the BlockTransferService and ShuffleClient, registers with the
- * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle
* service if configured.
*/
def initialize(appId: String): Unit = {
@@ -202,7 +202,7 @@ private[spark] class BlockManager(
blockManagerId
}
- master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+ master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
// Register Executors' configuration with the local shuffle service, if one should exist.
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
@@ -265,7 +265,7 @@ private[spark] class BlockManager(
def reregister(): Unit = {
// TODO: We might need to rate limit re-registering.
logInfo("BlockManager re-registering with master")
- master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+ master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
reportAllBlocks()
}
@@ -1215,7 +1215,7 @@ private[spark] class BlockManager(
shuffleClient.close()
}
diskBlockManager.stop()
- actorSystem.stop(slaveActor)
+ rpcEnv.stop(slaveEndpoint)
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
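The slaveActor lifecycle in BlockManager becomes a setupEndpoint/stop pair on the RpcEnv. A minimal sketch of that lifecycle, assuming only the RpcEnv calls exercised elsewhere in this patch (RpcEnv.create, setupEndpoint, send, stop, shutdown, awaitTermination); the endpoint name and message are illustrative:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcEndpoint, RpcEnv}

object EndpointLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    val env = RpcEnv.create("sketch", "localhost", 0, conf, new SecurityManager(conf))

    // Register an endpoint under a unique name, as BlockManager now does for
    // its slave endpoint via rpcEnv.setupEndpoint(...).
    val ref = env.setupEndpoint("sketch-endpoint", new RpcEndpoint {
      override val rpcEnv: RpcEnv = env
      override def receive: PartialFunction[Any, Unit] = {
        case msg => println(s"got $msg")
      }
    })

    ref.send("hello")

    // env.stop(ref) plays the role actorSystem.stop(slaveActor) used to play in BlockManager.stop().
    env.stop(ref)
    env.shutdown()
    env.awaitTermination()
  }
}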
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 061964826f..ceacf04302 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -20,35 +20,31 @@ package org.apache.spark.storage
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
-import akka.actor._
-
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.AkkaUtils
private[spark]
class BlockManagerMaster(
- var driverActor: ActorRef,
+ var driverEndpoint: RpcEndpointRef,
conf: SparkConf,
isDriver: Boolean)
extends Logging {
- private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf)
- private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf)
-
- val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
val timeout = AkkaUtils.askTimeout(conf)
- /** Remove a dead executor from the driver actor. This is only called on the driver side. */
+ /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */
def removeExecutor(execId: String) {
tell(RemoveExecutor(execId))
logInfo("Removed " + execId + " successfully in removeExecutor")
}
/** Register the BlockManager's id with the driver. */
- def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ def registerBlockManager(
+ blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = {
logInfo("Trying to register BlockManager")
- tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
+ tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
logInfo("Registered BlockManager")
}
@@ -59,7 +55,7 @@ class BlockManagerMaster(
memSize: Long,
diskSize: Long,
tachyonSize: Long): Boolean = {
- val res = askDriverWithReply[Boolean](
+ val res = driverEndpoint.askWithReply[Boolean](
UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize))
logDebug(s"Updated info of block $blockId")
res
@@ -67,12 +63,12 @@ class BlockManagerMaster(
/** Get locations of the blockId from the driver */
def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
- askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
+ driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
/** Get locations of multiple blockIds from the driver */
def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
- askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
+ driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
/**
@@ -85,11 +81,11 @@ class BlockManagerMaster(
/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
- askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
+ driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
- def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
- askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId))
}
/**
@@ -97,12 +93,12 @@ class BlockManagerMaster(
* blocks that the driver knows about.
*/
def removeBlock(blockId: BlockId) {
- askDriverWithReply(RemoveBlock(blockId))
+ driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId))
}
/** Remove all blocks belonging to the given RDD. */
def removeRdd(rddId: Int, blocking: Boolean) {
- val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
+ val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
future.onFailure {
case e: Exception =>
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}")
@@ -114,7 +110,7 @@ class BlockManagerMaster(
/** Remove all blocks belonging to the given shuffle. */
def removeShuffle(shuffleId: Int, blocking: Boolean) {
- val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+ val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
future.onFailure {
case e: Exception =>
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}")
@@ -126,7 +122,7 @@ class BlockManagerMaster(
/** Remove all blocks belonging to the given broadcast. */
def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
- val future = askDriverWithReply[Future[Seq[Int]]](
+ val future = driverEndpoint.askWithReply[Future[Seq[Int]]](
RemoveBroadcast(broadcastId, removeFromMaster))
future.onFailure {
case e: Exception =>
@@ -145,11 +141,11 @@ class BlockManagerMaster(
* amount of remaining memory.
*/
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
- askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
+ driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
def getStorageStatus: Array[StorageStatus] = {
- askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
+ driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus)
}
/**
@@ -165,11 +161,12 @@ class BlockManagerMaster(
askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
val msg = GetBlockStatus(blockId, askSlaves)
/*
- * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
+ * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint
* should not block on waiting for a block manager, which can in turn be waiting for the
- * master actor for a response to a prior message.
+ * master endpoint for a response to a prior message.
*/
- val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
+ val response = driverEndpoint.
+ askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
val (blockManagerIds, futures) = response.unzip
val result = Await.result(Future.sequence(futures), timeout)
if (result == null) {
@@ -193,33 +190,28 @@ class BlockManagerMaster(
filter: BlockId => Boolean,
askSlaves: Boolean): Seq[BlockId] = {
val msg = GetMatchingBlockIds(filter, askSlaves)
- val future = askDriverWithReply[Future[Seq[BlockId]]](msg)
+ val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg)
Await.result(future, timeout)
}
- /** Stop the driver actor, called only on the Spark driver node */
+ /** Stop the driver endpoint, called only on the Spark driver node */
def stop() {
- if (driverActor != null && isDriver) {
+ if (driverEndpoint != null && isDriver) {
tell(StopBlockManagerMaster)
- driverActor = null
+ driverEndpoint = null
logInfo("BlockManagerMaster stopped")
}
}
- /** Send a one-way message to the master actor, to which we expect it to reply with true. */
+ /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */
private def tell(message: Any) {
- if (!askDriverWithReply[Boolean](message)) {
- throw new SparkException("BlockManagerMasterActor returned false, expected true.")
+ if (!driverEndpoint.askWithReply[Boolean](message)) {
+ throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.")
}
}
- /**
- * Send a message to the driver actor and get its result within a default timeout, or
- * throw a SparkException if this fails.
- */
- private def askDriverWithReply[T](message: Any): T = {
- AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS,
- timeout)
- }
+}
+private[spark] object BlockManagerMaster {
+ val DRIVER_ENDPOINT_NAME = "BlockManagerMaster"
}
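The askDriverWithReply helper disappears above because the retry/timeout plumbing it wrapped (the old AKKA_RETRY_* constants) now appears to live behind RpcEndpointRef.askWithReply itself. A hedged sketch of the resulting client-side pattern, with a hypothetical Ping message standing in for the real BlockManagerMessages:

import org.apache.spark.SparkException
import org.apache.spark.rpc.RpcEndpointRef

// Hypothetical message; the real ones live in BlockManagerMessages.
case class Ping(id: Int)

class MasterClient(driverEndpoint: RpcEndpointRef) {
  /** Send a message to the master endpoint and insist on a `true` acknowledgement. */
  def tell(message: Any): Unit = {
    if (!driverEndpoint.askWithReply[Boolean](message)) {
      throw new SparkException("Master endpoint returned false, expected true.")
    }
  }

  /** Ask with a typed reply; the type parameter replaces Akka's ask + mapTo. */
  def ping(id: Int): Boolean = driverEndpoint.askWithReply[Boolean](Ping(id))
}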
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 5b53280161..28c73a7d54 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -21,25 +21,26 @@ import java.util.{HashMap => JHashMap}
import scala.collection.mutable
import scala.collection.JavaConversions._
-import scala.concurrent.Future
-import scala.concurrent.duration._
+import scala.concurrent.{ExecutionContext, Future}
-import akka.actor.{Actor, ActorRef}
-import akka.pattern.ask
-
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint}
+import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils}
+import org.apache.spark.util.Utils
/**
- * BlockManagerMasterActor is an actor on the master node to track statuses of
- * all slaves' block managers.
+ * BlockManagerMasterEndpoint is a [[ThreadSafeRpcEndpoint]] on the master node to track statuses
+ * of all slaves' block managers.
*/
private[spark]
-class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus)
- extends Actor with ActorLogReceive with Logging {
+class BlockManagerMasterEndpoint(
+ override val rpcEnv: RpcEnv,
+ val isLocal: Boolean,
+ conf: SparkConf,
+ listenerBus: LiveListenerBus)
+ extends ThreadSafeRpcEndpoint with Logging {
// Mapping from block manager id to the block manager's information.
private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
@@ -50,68 +51,67 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
// Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
- private val akkaTimeout = AkkaUtils.askTimeout(conf)
+ private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool")
+ private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
- register(blockManagerId, maxMemSize, slaveActor)
- sender ! true
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
+ register(blockManagerId, maxMemSize, slaveEndpoint)
+ context.reply(true)
case UpdateBlockInfo(
blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) =>
- sender ! updateBlockInfo(
- blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)
+ context.reply(updateBlockInfo(
+ blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize))
case GetLocations(blockId) =>
- sender ! getLocations(blockId)
+ context.reply(getLocations(blockId))
case GetLocationsMultipleBlockIds(blockIds) =>
- sender ! getLocationsMultipleBlockIds(blockIds)
+ context.reply(getLocationsMultipleBlockIds(blockIds))
case GetPeers(blockManagerId) =>
- sender ! getPeers(blockManagerId)
+ context.reply(getPeers(blockManagerId))
- case GetActorSystemHostPortForExecutor(executorId) =>
- sender ! getActorSystemHostPortForExecutor(executorId)
+ case GetRpcHostPortForExecutor(executorId) =>
+ context.reply(getRpcHostPortForExecutor(executorId))
case GetMemoryStatus =>
- sender ! memoryStatus
+ context.reply(memoryStatus)
case GetStorageStatus =>
- sender ! storageStatus
+ context.reply(storageStatus)
case GetBlockStatus(blockId, askSlaves) =>
- sender ! blockStatus(blockId, askSlaves)
+ context.reply(blockStatus(blockId, askSlaves))
case GetMatchingBlockIds(filter, askSlaves) =>
- sender ! getMatchingBlockIds(filter, askSlaves)
+ context.reply(getMatchingBlockIds(filter, askSlaves))
case RemoveRdd(rddId) =>
- sender ! removeRdd(rddId)
+ context.reply(removeRdd(rddId))
case RemoveShuffle(shuffleId) =>
- sender ! removeShuffle(shuffleId)
+ context.reply(removeShuffle(shuffleId))
case RemoveBroadcast(broadcastId, removeFromDriver) =>
- sender ! removeBroadcast(broadcastId, removeFromDriver)
+ context.reply(removeBroadcast(broadcastId, removeFromDriver))
case RemoveBlock(blockId) =>
removeBlockFromWorkers(blockId)
- sender ! true
+ context.reply(true)
case RemoveExecutor(execId) =>
removeExecutor(execId)
- sender ! true
+ context.reply(true)
case StopBlockManagerMaster =>
- sender ! true
- context.stop(self)
+ context.reply(true)
+ stop()
case BlockManagerHeartbeat(blockManagerId) =>
- sender ! heartbeatReceived(blockManagerId)
+ context.reply(heartbeatReceived(blockManagerId))
- case other =>
- logWarning("Got unknown message: " + other)
}
private def removeRdd(rddId: Int): Future[Seq[Int]] = {
@@ -129,22 +129,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
// Ask the slaves to remove the RDD, and put the result in a sequence of Futures.
// The dispatcher is used as an implicit argument into the Future sequence construction.
- import context.dispatcher
val removeMsg = RemoveRdd(rddId)
Future.sequence(
blockManagerInfo.values.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ bm.slaveEndpoint.sendWithReply[Int](removeMsg)
}.toSeq
)
}
private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
- // Nothing to do in the BlockManagerMasterActor data structures
- import context.dispatcher
+ // Nothing to do in the BlockManagerMasterEndpoint data structures
val removeMsg = RemoveShuffle(shuffleId)
Future.sequence(
blockManagerInfo.values.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
+ bm.slaveEndpoint.sendWithReply[Boolean](removeMsg)
}.toSeq
)
}
@@ -155,14 +153,13 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
* from the executors, but not from the driver.
*/
private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
- import context.dispatcher
val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
val requiredBlockManagers = blockManagerInfo.values.filter { info =>
removeFromDriver || !info.blockManagerId.isDriver
}
Future.sequence(
requiredBlockManagers.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ bm.slaveEndpoint.sendWithReply[Int](removeMsg)
}.toSeq
)
}
@@ -217,7 +214,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
// Remove the block from the slave's BlockManager.
// Doesn't actually wait for a confirmation and the message might get lost.
// If message loss becomes frequent, we should add retry logic here.
- blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout)
+ blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId))
}
}
}
@@ -247,17 +244,16 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
private def blockStatus(
blockId: BlockId,
askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
- import context.dispatcher
val getBlockStatus = GetBlockStatus(blockId)
/*
- * Rather than blocking on the block status query, master actor should simply return
+ * Rather than blocking on the block status query, master endpoint should simply return
* Futures to avoid potential deadlocks. This can arise if there exists a block manager
- * that is also waiting for this master actor's response to a previous message.
+ * that is also waiting for this master endpoint's response to a previous message.
*/
blockManagerInfo.values.map { info =>
val blockStatusFuture =
if (askSlaves) {
- info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]]
+ info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus)
} else {
Future { info.getStatus(blockId) }
}
@@ -276,13 +272,12 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
private def getMatchingBlockIds(
filter: BlockId => Boolean,
askSlaves: Boolean): Future[Seq[BlockId]] = {
- import context.dispatcher
val getMatchingBlockIds = GetMatchingBlockIds(filter)
Future.sequence(
blockManagerInfo.values.map { info =>
val future =
if (askSlaves) {
- info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]]
+ info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds)
} else {
Future { info.blocks.keys.filter(filter).toSeq }
}
@@ -291,7 +286,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
).map(_.flatten.toSeq)
}
- private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) {
val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@@ -308,7 +303,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
blockManagerIdByExecutor(id.executorId) = id
blockManagerInfo(id) = new BlockManagerInfo(
- id, System.currentTimeMillis(), maxMemSize, slaveActor)
+ id, System.currentTimeMillis(), maxMemSize, slaveEndpoint)
}
listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
}
@@ -379,19 +374,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}
/**
- * Returns the hostname and port of an executor's actor system, based on the Akka address of its
- * BlockManagerSlaveActor.
+ * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its
+ * [[BlockManagerSlaveEndpoint]].
*/
- private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = {
for (
blockManagerId <- blockManagerIdByExecutor.get(executorId);
- info <- blockManagerInfo.get(blockManagerId);
- host <- info.slaveActor.path.address.host;
- port <- info.slaveActor.path.address.port
+ info <- blockManagerInfo.get(blockManagerId)
) yield {
- (host, port)
+ (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port)
}
}
+
+ override def onStop(): Unit = {
+ askThreadPool.shutdownNow()
+ }
}
@DeveloperApi
@@ -412,7 +409,7 @@ private[spark] class BlockManagerInfo(
val blockManagerId: BlockManagerId,
timeMs: Long,
val maxMem: Long,
- val slaveActor: ActorRef)
+ val slaveEndpoint: RpcEndpointRef)
extends Logging {
private var _lastSeenMs: Long = timeMs
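The master endpoint above replies through an RpcCallContext instead of Akka's implicit sender, and supplies its own ExecutionContext (replacing `import context.dispatcher`) for the asynchronous removal futures. A condensed sketch of that shape, assuming the callbacks shown in this patch; the messages and endpoint name are made up:

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.util.Utils

// Hypothetical messages standing in for the BlockManagerMessages hierarchy.
case class Register(id: String)
case class RemoveAll(id: String)

class MasterLikeEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  // Replaces `import context.dispatcher`: futures run on the endpoint's own pool.
  private val askThreadPool = Utils.newDaemonCachedThreadPool("sketch-ask-thread-pool")
  private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)

  private val registered = scala.collection.mutable.HashSet.empty[String]

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Register(id) =>
      registered += id
      context.reply(true)                      // synchronous acknowledgement

    case RemoveAll(id) =>
      // Reply with a Future so the endpoint never blocks on slow work.
      context.reply(Future { registered.remove(id) })
  }

  override def onStop(): Unit = {
    askThreadPool.shutdownNow()
  }
}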
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 48247453ed..f89d8d7493 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -19,8 +19,7 @@ package org.apache.spark.storage
import java.io.{Externalizable, ObjectInput, ObjectOutput}
-import akka.actor.ActorRef
-
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
private[spark] object BlockManagerMessages {
@@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages {
case class RegisterBlockManager(
blockManagerId: BlockManagerId,
maxMemSize: Long,
- sender: ActorRef)
+ sender: RpcEndpointRef)
extends ToBlockManagerMaster
case class UpdateBlockInfo(
@@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
- case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+ case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
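The message change above is what lets the master keep a handle it can call back on: registration now carries an RpcEndpointRef rather than an ActorRef. A tiny illustrative sketch (the message name and fields are hypothetical):

import org.apache.spark.rpc.RpcEndpointRef

// Hypothetical registration message: carrying an RpcEndpointRef lets the
// receiver call the sender back later, as the master does with
// slaveEndpoint.sendWithReply.
case class RegisterWorker(workerId: String, maxMemBytes: Long, sender: RpcEndpointRef)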
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
index 52fb896c4e..8980fa8eb7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -17,41 +17,43 @@
package org.apache.spark.storage
-import scala.concurrent.Future
-
-import akka.actor.{ActorRef, Actor}
+import scala.concurrent.{ExecutionContext, Future}
+import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
+import org.apache.spark.util.Utils
import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.ActorLogReceive
/**
- * An actor to take commands from the master to execute options. For example,
+ * An RpcEndpoint to take commands from the master to execute operations. For example,
* this is used to remove blocks from the slave's BlockManager.
*/
private[storage]
-class BlockManagerSlaveActor(
+class BlockManagerSlaveEndpoint(
+ override val rpcEnv: RpcEnv,
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
- extends Actor with ActorLogReceive with Logging {
+ extends RpcEndpoint with Logging {
- import context.dispatcher
+ private val asyncThreadPool =
+ Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
+ private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)
// Operations that involve removing blocks may be slow and should be done asynchronously
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RemoveBlock(blockId) =>
- doAsync[Boolean]("removing block " + blockId, sender) {
+ doAsync[Boolean]("removing block " + blockId, context) {
blockManager.removeBlock(blockId)
true
}
case RemoveRdd(rddId) =>
- doAsync[Int]("removing RDD " + rddId, sender) {
+ doAsync[Int]("removing RDD " + rddId, context) {
blockManager.removeRdd(rddId)
}
case RemoveShuffle(shuffleId) =>
- doAsync[Boolean]("removing shuffle " + shuffleId, sender) {
+ doAsync[Boolean]("removing shuffle " + shuffleId, context) {
if (mapOutputTracker != null) {
mapOutputTracker.unregisterShuffle(shuffleId)
}
@@ -59,30 +61,34 @@ class BlockManagerSlaveActor(
}
case RemoveBroadcast(broadcastId, _) =>
- doAsync[Int]("removing broadcast " + broadcastId, sender) {
+ doAsync[Int]("removing broadcast " + broadcastId, context) {
blockManager.removeBroadcast(broadcastId, tellMaster = true)
}
case GetBlockStatus(blockId, _) =>
- sender ! blockManager.getStatus(blockId)
+ context.reply(blockManager.getStatus(blockId))
case GetMatchingBlockIds(filter, _) =>
- sender ! blockManager.getMatchingBlockIds(filter)
+ context.reply(blockManager.getMatchingBlockIds(filter))
}
- private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
+ private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) {
val future = Future {
logDebug(actionMessage)
body
}
future.onSuccess { case response =>
logDebug("Done " + actionMessage + ", response is " + response)
- responseActor ! response
- logDebug("Sent response: " + response + " to " + responseActor)
+ context.reply(response)
+ logDebug("Sent response: " + response + " to " + context.sender)
}
future.onFailure { case t: Throwable =>
logError("Error in " + actionMessage, t)
- responseActor ! null.asInstanceOf[T]
+ context.sendFailure(t)
}
}
+
+ override def onStop(): Unit = {
+ asyncThreadPool.shutdownNow()
+ }
}
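The doAsync helper above is the generic shape for off-loop work on an endpoint: run the body on a side thread pool, reply on success, and propagate the exception to the caller via sendFailure instead of replying with null. A hedged, standalone restatement of that pattern (trait and names are hypothetical; the context calls are the ones used in this patch):

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

import org.apache.spark.Logging
import org.apache.spark.rpc.RpcCallContext

trait AsyncReplies extends Logging {
  protected implicit def asyncExecutionContext: ExecutionContext

  /** Run `body` off the message loop and route the result (or error) back to the caller. */
  protected def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    Future {
      logDebug(actionMessage)
      body
    }.onComplete {
      case Success(response) =>
        logDebug(s"Done $actionMessage, response is $response")
        context.reply(response)
      case Failure(t) =>
        logError(s"Error in $actionMessage", t)
        context.sendFailure(t)   // the caller's ask fails instead of receiving null
    }
  }
}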
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 7c85e28679..0fdfaf300e 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1214,6 +1214,16 @@ private[spark] object Utils extends Logging {
}
}
+ /** Executes the given block, logging any non-fatal errors and rethrowing only fatal errors. */
+ def tryLogNonFatalError(block: => Unit) {
+ try {
+ block
+ } catch {
+ case NonFatal(t) =>
+ logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
+ }
+ }
+
/**
* Execute a block of code, then a finally block, but if exceptions happen in
* the finally block, do not suppress the original exception.
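Utils.tryLogNonFatalError is what LocalEndpoint's revive task wraps its body in; a minimal usage sketch (the wrapped body here is made up):

import org.apache.spark.util.Utils

object TryLogSketch {
  def main(args: Array[String]): Unit = {
    // Swallow and log anything non-fatal so a background thread keeps running;
    // fatal errors still propagate to the caller.
    Utils.tryLogNonFatalError {
      println("periodic work that might throw")
    }
  }
}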
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
new file mode 100644
index 0000000000..0fd570e529
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.BlockManagerId
+import org.scalatest.FunSuite
+import org.mockito.Mockito.{mock, spy, verify, when}
+import org.mockito.Matchers
+import org.mockito.Matchers._
+
+import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.util.RpcUtils
+import org.scalatest.concurrent.Eventually._
+
+class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext {
+
+ test("HeartbeatReceiver") {
+ sc = spy(new SparkContext("local[2]", "test"))
+ val scheduler = mock(classOf[TaskScheduler])
+ when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
+ when(sc.taskScheduler).thenReturn(scheduler)
+
+ val heartbeatReceiver = new HeartbeatReceiver(sc)
+ sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(heartbeatReceiver.scheduler != null)
+ }
+ val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+
+ val metrics = new TaskMetrics
+ val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
+ val response = receiverRef.askWithReply[HeartbeatResponse](
+ Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+
+ verify(scheduler).executorHeartbeatReceived(
+ Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+ assert(false === response.reregisterBlockManager)
+ }
+
+ test("HeartbeatReceiver re-register") {
+ sc = spy(new SparkContext("local[2]", "test"))
+ val scheduler = mock(classOf[TaskScheduler])
+ when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false)
+ when(sc.taskScheduler).thenReturn(scheduler)
+
+ val heartbeatReceiver = new HeartbeatReceiver(sc)
+ sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(heartbeatReceiver.scheduler != null)
+ }
+ val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+
+ val metrics = new TaskMetrics
+ val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
+ val response = receiverRef.askWithReply[HeartbeatResponse](
+ Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+
+ verify(scheduler).executorHeartbeatReceived(
+ Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+ assert(true === response.reregisterBlockManager)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index e07bdb9637..4f19c4f211 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -311,7 +311,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
}
test("self: call in onStop") {
- @volatile var e: Throwable = null
+ @volatile var selfOption: Option[RpcEndpointRef] = null
val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint {
override val rpcEnv = env
@@ -321,20 +321,18 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
}
override def onStop(): Unit = {
- self
+ selfOption = Option(self)
}
override def onError(cause: Throwable): Unit = {
- e = cause
}
})
env.stop(endpointRef)
eventually(timeout(5 seconds), interval(10 millis)) {
- // Calling `self` in `onStop` is invalid
- assert(e != null)
- assert(e.getMessage.contains("Cannot find RpcEndpointRef"))
+ // Calling `self` in `onStop` will return null, so selfOption will be None
+ assert(selfOption == None)
}
}
@@ -342,7 +340,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
// If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it
for(i <- 0 until 100) {
@volatile var result = 0
- val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint {
+ val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint {
override val rpcEnv = env
override def receive = {
@@ -475,7 +473,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
test("network events") {
val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
- env.setupThreadSafeEndpoint("network-events", new RpcEndpoint {
+ env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint {
override val rpcEnv = env
override def receive = {
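The suite changes above reflect the API change called out in the commit message: thread safety is now declared by the endpoint itself (extending ThreadSafeRpcEndpoint) rather than chosen at registration time via setupThreadSafeEndpoint. A small sketch of the new style with an illustrative counter endpoint:

import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint}

// Illustrative endpoint: because it extends ThreadSafeRpcEndpoint, the RpcEnv
// processes its messages one at a time, so the mutable counter needs no locking.
class CounterEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  private var count = 0

  override def receive: PartialFunction[Any, Unit] = {
    case "inc"   => count += 1
    case "print" => println(count)
  }
}

// Registration no longer needs a dedicated setupThreadSafeEndpoint method:
//   val ref = rpcEnv.setupEndpoint("counter", new CounterEndpoint(rpcEnv))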
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index c2903c8597..b4de90b65d 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -22,11 +22,11 @@ import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.language.postfixOps
-import akka.actor.{ActorSystem, Props}
import org.mockito.Mockito.{mock, when}
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester}
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
import org.scalatest.concurrent.Eventually._
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager}
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
@@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.storage.StorageLevel._
-import org.apache.spark.util.{AkkaUtils, SizeEstimator}
/** Testsuite that tests block replication in BlockManager */
class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter {
private val conf = new SparkConf(false)
- var actorSystem: ActorSystem = null
+ var rpcEnv: RpcEnv = null
var master: BlockManagerMaster = null
val securityMgr = new SecurityManager(conf)
val mapOutputTracker = new MapOutputTrackerMaster(conf)
@@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
- val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
store.initialize("app-id")
allStores += store
@@ -69,12 +68,10 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
}
before {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "test", "localhost", 0, conf = conf, securityManager = securityMgr)
- this.actorSystem = actorSystem
+ rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr)
conf.set("spark.authenticate", "false")
- conf.set("spark.driver.port", boundPort.toString)
+ conf.set("spark.driver.port", rpcEnv.address.port.toString)
conf.set("spark.storage.unrollFraction", "0.4")
conf.set("spark.storage.unrollMemoryThreshold", "512")
@@ -83,18 +80,17 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
// to make cached peers refresh frequently
conf.set("spark.storage.cachedPeersTtl", "10")
- master = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
- conf, true)
+ master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
+ new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true)
allStores.clear()
}
after {
allStores.foreach { _.stop() }
allStores.clear()
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- actorSystem = null
+ rpcEnv.shutdown()
+ rpcEnv.awaitTermination()
+ rpcEnv = null
master = null
}
@@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
val failableTransfer = mock(classOf[BlockTransferService]) // this won't actually work
when(failableTransfer.hostName).thenReturn("some-hostname")
when(failableTransfer.port).thenReturn(1000)
- val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
+ val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer,
10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
failableStore.initialize("app-id")
allStores += failableStore // so that this gets stopped after test
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index ecd1cba5b5..283090e3bd 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -19,24 +19,18 @@ package org.apache.spark.storage
import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays
-import java.util.concurrent.TimeUnit
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.language.postfixOps
-import akka.actor._
-import akka.pattern.ask
-import akka.util.Timeout
-
import org.mockito.Mockito.{mock, when}
-
import org.scalatest._
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager}
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.network.nio.NioBlockTransferService
@@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
private val conf = new SparkConf(false)
var store: BlockManager = null
var store2: BlockManager = null
- var actorSystem: ActorSystem = null
+ var rpcEnv: RpcEnv = null
var master: BlockManagerMaster = null
conf.set("spark.authenticate", "false")
val securityMgr = new SecurityManager(conf)
@@ -72,28 +66,25 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
- val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
manager.initialize("app-id")
manager
}
override def beforeEach(): Unit = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "test", "localhost", 0, conf = conf, securityManager = securityMgr)
- this.actorSystem = actorSystem
+ rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr)
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
System.setProperty("os.arch", "amd64")
conf.set("os.arch", "amd64")
conf.set("spark.test.useCompressedOops", "true")
- conf.set("spark.driver.port", boundPort.toString)
+ conf.set("spark.driver.port", rpcEnv.address.port.toString)
conf.set("spark.storage.unrollFraction", "0.4")
conf.set("spark.storage.unrollMemoryThreshold", "512")
- master = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
- conf, true)
+ master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
+ new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true)
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
@@ -108,9 +99,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
store2.stop()
store2 = null
}
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- actorSystem = null
+ rpcEnv.shutdown()
+ rpcEnv.awaitTermination()
+ rpcEnv = null
master = null
}
@@ -357,10 +348,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- implicit val timeout = Timeout(30, TimeUnit.SECONDS)
- val reregister = !Await.result(
- master.driverActor ? BlockManagerHeartbeat(store.blockManagerId),
- timeout.duration).asInstanceOf[Boolean]
+ val reregister = !master.driverEndpoint.askWithReply[Boolean](
+ BlockManagerHeartbeat(store.blockManagerId))
assert(reregister == true)
}
@@ -785,7 +774,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
val transfer = new NioBlockTransferService(conf, securityMgr)
- store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master,
+ store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr,
0)