author    zsxwing <zsxwing@gmail.com>    2015-04-05 21:57:15 -0700
committer Reynold Xin <rxin@databricks.com>    2015-04-05 21:57:15 -0700
commit    0b5d028a93b7d5adb148fbf3a576257bb3a6d8cb (patch)
tree      ff2700f33ca69693f59608ea53f2fbee3fbd2490
parent    acffc43455d7b3e4000be4ff0175b8ea19cd280b (diff)
[SPARK-6602][Core] Update MapOutputTrackerMasterActor to MapOutputTrackerMasterEndpoint
This is the second PR for [SPARK-6602]. It updated MapOutputTrackerMasterActor and its unit tests.

cc rxin

Author: zsxwing <zsxwing@gmail.com>

Closes #5371 from zsxwing/rpc-rewrite-part2 and squashes the following commits:

fcf3816 [zsxwing] Fix the code style
4013a22 [zsxwing] Add doc for uncaught exceptions in RpcEnv
93c6c20 [zsxwing] Add an example of UnserializableException and add ErrorMonitor to monitor errors from Akka
134fe7b [zsxwing] Update MapOutputTrackerMasterActor to MapOutputTrackerMasterEndpoint
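For orientation before the diff: the core of this migration is replacing the Akka `Actor.receive` / `sender ! reply` pattern with `RpcEndpoint.receiveAndReply`, where replies and failures go through an explicit `RpcCallContext`. The sketch below is a minimal, simplified illustration of that shape; `EchoEndpoint` and `Echo` are hypothetical names, not part of this patch, and the real change is in `MapOutputTracker.scala` below.

```scala
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

// Hypothetical endpoint illustrating the pattern this patch applies to
// MapOutputTrackerMasterEndpoint: reply via RpcCallContext instead of `sender !`.
case class Echo(msg: String)

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Echo(msg) =>
      // Success path: send the answer back to the asker.
      context.reply(msg.toUpperCase)
    case unexpected =>
      // Failure path: propagate an exception to the asker instead of throwing,
      // which is what MapOutputTrackerMasterEndpoint now does for oversized replies.
      context.sendFailure(new IllegalArgumentException(s"Unexpected message: $unexpected"))
  }
}
```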
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala        |  61
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                 |  18
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala               |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala      |  19
-rw-r--r--  core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala    | 100
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala          |  33
-rw-r--r--  core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala      | 198
7 files changed, 221 insertions, 212 deletions
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 5718951451..d65c94e410 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,13 +21,11 @@ import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.{HashSet, HashMap, Map}
-import scala.concurrent.Await
+import scala.collection.mutable.{HashSet, Map}
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
-import akka.actor._
-import akka.pattern.ask
-
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.BlockManagerId
@@ -38,14 +36,15 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-/** Actor class for MapOutputTrackerMaster */
-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
- extends Actor with ActorLogReceive with Logging {
+/** RpcEndpoint class for MapOutputTrackerMaster */
+private[spark] class MapOutputTrackerMasterEndpoint(
+ override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf)
+ extends RpcEndpoint with Logging {
val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
- val hostPort = sender.path.address.hostPort
+ val hostPort = context.sender.address.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
val serializedSize = mapOutputStatuses.size
@@ -53,19 +52,19 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
val msg = s"Map output statuses were $serializedSize bytes which " +
s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
- /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
- * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
- * will ultimately remove this entire code path. */
+ /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
+ * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
val exception = new SparkException(msg)
logError(msg, exception)
- throw exception
+ context.sendFailure(exception)
+ } else {
+ context.reply(mapOutputStatuses)
}
- sender ! mapOutputStatuses
case StopMapOutputTracker =>
- logInfo("MapOutputTrackerActor stopped!")
- sender ! true
- context.stop(self)
+ logInfo("MapOutputTrackerMasterEndpoint stopped!")
+ context.reply(true)
+ stop()
}
}
@@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
* (driver and executor) use different HashMap to store its metadata.
*/
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
- private val timeout = AkkaUtils.askTimeout(conf)
- private val retryAttempts = AkkaUtils.numRetries(conf)
- private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
- /** Set to the MapOutputTrackerActor living on the driver. */
- var trackerActor: ActorRef = _
+ /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
+ var trackerEndpoint: RpcEndpointRef = _
/**
* This HashMap has different behavior for the driver and the executors.
@@ -105,12 +101,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
private val fetching = new HashSet[Int]
/**
- * Send a message to the trackerActor and get its result within a default timeout, or
+ * Send a message to the trackerEndpoint and get its result within a default timeout, or
* throw a SparkException if this fails.
*/
- protected def askTracker(message: Any): Any = {
+ protected def askTracker[T: ClassTag](message: Any): T = {
try {
- AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout)
+ trackerEndpoint.askWithReply[T](message)
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
@@ -118,9 +114,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
}
- /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
+ /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */
protected def sendTracker(message: Any) {
- val response = askTracker(message)
+ val response = askTracker[Boolean](message)
if (response != true) {
throw new SparkException(
"Error reply received from MapOutputTracker. Expecting true, got " + response.toString)
@@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
- logInfo("Doing the fetch; tracker actor = " + trackerActor)
+ logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
// This try-finally prevents hangs due to timeouts:
try {
- val fetchedBytes =
- askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
+ val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
@@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
override def stop() {
sendTracker(StopMapOutputTracker)
mapStatuses.clear()
- trackerActor = null
+ trackerEndpoint = null
metadataCleaner.cancel()
cachedSerializedStatuses.clear()
}
@@ -350,6 +345,8 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
private[spark] object MapOutputTracker extends Logging {
+ val ENDPOINT_NAME = "MapOutputTracker"
+
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 55be0a59fe..0171488e09 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.util.Properties
-import akka.actor._
import com.google.common.collect.MapMaker
import org.apache.spark.annotation.DeveloperApi
@@ -41,7 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
-import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
+import org.apache.spark.util.{RpcUtils, Utils}
/**
* :: DeveloperApi ::
@@ -286,15 +285,6 @@ object SparkEnv extends Logging {
val closureSerializer = instantiateClassFromConf[Serializer](
"spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
- def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
- if (isDriver) {
- logInfo("Registering " + name)
- actorSystem.actorOf(Props(newActor), name = name)
- } else {
- AkkaUtils.makeDriverRef(name, conf, actorSystem)
- }
- }
-
def registerOrLookupEndpoint(
name: String, endpointCreator: => RpcEndpoint):
RpcEndpointRef = {
@@ -314,9 +304,9 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
- mapOutputTracker.trackerActor = registerOrLookup(
- "MapOutputTracker",
- new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
+ mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(
+ rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index d47e41abcf..e259867c14 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -30,7 +30,9 @@ import org.apache.spark.util.{AkkaUtils, Utils}
/**
* An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to
* receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote
- * nodes, and deliver them to corresponding [[RpcEndpoint]]s.
+ * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by
+ * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the
+ * sender, or logging them if no such sender or `NotSerializableException`.
*
* [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri.
*/
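For context on the doc change above: an exception passed to `RpcCallContext.sendFailure` (by an endpoint, or by `RpcEnv` itself for uncaught exceptions) surfaces on the asking side as an exception thrown from the blocking ask. A rough caller-side sketch, closely following the `askTracker` helper added to `MapOutputTracker.scala` in this patch (the function name here is a placeholder):

```scala
import scala.reflect.ClassTag

import org.apache.spark.SparkException
import org.apache.spark.rpc.RpcEndpointRef

// Caller-side view of sendFailure: a failure sent by the endpoint comes back
// as an exception thrown by askWithReply, which we wrap in a SparkException.
def askTrackerSketch[T: ClassTag](trackerEndpoint: RpcEndpointRef, message: Any): T = {
  try {
    // Blocks until the endpoint calls context.reply(...) or context.sendFailure(...).
    trackerEndpoint.askWithReply[T](message)
  } catch {
    case e: Exception =>
      throw new SparkException("Error communicating with MapOutputTracker", e)
  }
}
```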
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 9e06147dff..652e52f2b2 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -17,16 +17,16 @@
package org.apache.spark.rpc.akka
-import java.net.URI
import java.util.concurrent.ConcurrentHashMap
-import scala.concurrent.{Await, Future}
+import scala.concurrent.Future
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address}
+import akka.event.Logging.Error
import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
import org.apache.spark.{SparkException, Logging, SparkConf}
@@ -242,10 +242,25 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
def create(config: RpcEnvConfig): RpcEnv = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
config.name, config.host, config.port, config.conf, config.securityManager)
+ actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
new AkkaRpcEnv(actorSystem, config.conf, boundPort)
}
}
+/**
+ * Monitor errors reported by Akka and log them.
+ */
+private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging {
+
+ override def preStart(): Unit = {
+ context.system.eventStream.subscribe(self, classOf[Error])
+ }
+
+ override def receiveWithLogging: Actor.Receive = {
+ case Error(cause: Throwable, _, _, message: String) => logError(message, cause)
+ }
+}
+
private[akka] class AkkaRpcEndpointRef(
@transient defaultAddress: RpcAddress,
@transient _actorRef: => ActorRef,
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ccfe0678cb..6295d34be5 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -17,34 +17,37 @@
package org.apache.spark
-import scala.concurrent.Await
-
-import akka.actor._
-import akka.testkit.TestActorRef
+import org.mockito.Mockito._
+import org.mockito.Matchers.{any, isA}
import org.scalatest.FunSuite
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.AkkaUtils
class MapOutputTrackerSuite extends FunSuite {
private val conf = new SparkConf
+ def createRpcEnv(name: String, host: String = "localhost", port: Int = 0,
+ securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = {
+ RpcEnv.create(name, host, port, conf, securityManager)
+ }
+
test("master start and stop") {
- val actorSystem = ActorSystem("test")
+ val rpcEnv = createRpcEnv("test")
val tracker = new MapOutputTrackerMaster(conf)
- tracker.trackerActor =
- actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.stop()
- actorSystem.shutdown()
+ rpcEnv.shutdown()
}
test("master register shuffle and fetch") {
- val actorSystem = ActorSystem("test")
+ val rpcEnv = createRpcEnv("test")
val tracker = new MapOutputTrackerMaster(conf)
- tracker.trackerActor =
- actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
assert(tracker.containsShuffle(10))
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
@@ -57,13 +60,14 @@ class MapOutputTrackerSuite extends FunSuite {
assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
(BlockManagerId("b", "hostB", 1000), size10000)))
tracker.stop()
- actorSystem.shutdown()
+ rpcEnv.shutdown()
}
test("master register and unregister shuffle") {
- val actorSystem = ActorSystem("test")
+ val rpcEnv = createRpcEnv("test")
val tracker = new MapOutputTrackerMaster(conf)
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapStatus.compressSize(1000L)
val compressedSize10000 = MapStatus.compressSize(10000L)
@@ -78,14 +82,14 @@ class MapOutputTrackerSuite extends FunSuite {
assert(tracker.getServerStatuses(10, 0).isEmpty)
tracker.stop()
- actorSystem.shutdown()
+ rpcEnv.shutdown()
}
test("master register shuffle and unregister map output and fetch") {
- val actorSystem = ActorSystem("test")
+ val rpcEnv = createRpcEnv("test")
val tracker = new MapOutputTrackerMaster(conf)
- tracker.trackerActor =
- actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapStatus.compressSize(1000L)
val compressedSize10000 = MapStatus.compressSize(10000L)
@@ -104,25 +108,21 @@ class MapOutputTrackerSuite extends FunSuite {
intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
tracker.stop()
- actorSystem.shutdown()
+ rpcEnv.shutdown()
}
test("remote fetch") {
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
- securityManager = new SecurityManager(conf))
+ val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
- securityManager = new SecurityManager(conf))
+ val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
masterTracker.registerShuffle(10, 1)
masterTracker.incrementEpoch()
@@ -147,8 +147,8 @@ class MapOutputTrackerSuite extends FunSuite {
masterTracker.stop()
slaveTracker.stop()
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
test("remote fetch below akka frame size") {
@@ -157,19 +157,24 @@ class MapOutputTrackerSuite extends FunSuite {
newConf.set("spark.akka.askTimeout", "1") // Fail fast
val masterTracker = new MapOutputTrackerMaster(conf)
- val actorSystem = ActorSystem("test")
- val actorRef = TestActorRef[MapOutputTrackerMasterActor](
- Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem)
- val masterActor = actorRef.underlyingActor
+ val rpcEnv = createRpcEnv("spark")
+ val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
+ rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
// Frame size should be ~123B, and no exception should be thrown
masterTracker.registerShuffle(10, 1)
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0)))
- masterActor.receive(GetMapOutputStatuses(10))
+ val sender = mock(classOf[RpcEndpointRef])
+ when(sender.address).thenReturn(RpcAddress("localhost", 12345))
+ val rpcCallContext = mock(classOf[RpcCallContext])
+ when(rpcCallContext.sender).thenReturn(sender)
+ masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10))
+ verify(rpcCallContext).reply(any())
+ verify(rpcCallContext, never()).sendFailure(any())
// masterTracker.stop() // this throws an exception
- actorSystem.shutdown()
+ rpcEnv.shutdown()
}
test("remote fetch exceeds akka frame size") {
@@ -178,12 +183,11 @@ class MapOutputTrackerSuite extends FunSuite {
newConf.set("spark.akka.askTimeout", "1") // Fail fast
val masterTracker = new MapOutputTrackerMaster(conf)
- val actorSystem = ActorSystem("test")
- val actorRef = TestActorRef[MapOutputTrackerMasterActor](
- Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem)
- val masterActor = actorRef.underlyingActor
+ val rpcEnv = createRpcEnv("test")
+ val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
+ rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
- // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception.
+ // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception.
// Note that the size is hand-selected here because map output statuses are compressed before
// being sent.
masterTracker.registerShuffle(20, 100)
@@ -191,9 +195,15 @@ class MapOutputTrackerSuite extends FunSuite {
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
}
- intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
+ val sender = mock(classOf[RpcEndpointRef])
+ when(sender.address).thenReturn(RpcAddress("localhost", 12345))
+ val rpcCallContext = mock(classOf[RpcCallContext])
+ when(rpcCallContext.sender).thenReturn(sender)
+ masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20))
+ verify(rpcCallContext, never()).reply(any())
+ verify(rpcCallContext).sendFailure(isA(classOf[SparkException]))
// masterTracker.stop() // this throws an exception
- actorSystem.shutdown()
+ rpcEnv.shutdown()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 4f19c4f211..5a734ec5ba 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -514,10 +514,35 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
("onDisconnected", remoteAddress)))
}
}
-}
-case object Start
+ test("sendWithReply: unserializable error") {
+ env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint {
+ override val rpcEnv = env
-case class Ping(id: Int)
+ override def receiveAndReply(context: RpcCallContext) = {
+ case msg: String => context.sendFailure(new UnserializableException)
+ }
+ })
-case class Pong(id: Int)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+ // Use anotherEnv to find out the RpcEndpointRef
+ val rpcEndpointRef = anotherEnv.setupEndpointRef(
+ "local", env.address, "sendWithReply-unserializable-error")
+ try {
+ val f = rpcEndpointRef.sendWithReply[String]("hello")
+ intercept[TimeoutException] {
+ Await.result(f, 1 seconds)
+ }
+ } finally {
+ anotherEnv.shutdown()
+ anotherEnv.awaitTermination()
+ }
+ }
+
+}
+
+class UnserializableClass
+
+class UnserializableException extends Exception {
+ private val unserializableField = new UnserializableClass
+}
diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
index 6250d50fb7..bec79fc4dc 100644
--- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -19,14 +19,11 @@ package org.apache.spark.util
import java.util.concurrent.TimeoutException
-import scala.concurrent.Await
-import scala.util.{Failure, Try}
-
-import akka.actor._
-
+import akka.actor.ActorNotFound
import org.scalatest.FunSuite
import org.apache.spark._
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.SSLSampleConfigs._
@@ -39,39 +36,37 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
test("remote fetch security bad password") {
val conf = new SparkConf
+ conf.set("spark.rpc", "akka")
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
val securityManager = new SecurityManager(conf)
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager)
+ System.setProperty("spark.hostPort", rpcEnv.address.hostPort)
assert(securityManager.isAuthenticationEnabled() === true)
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val badconf = new SparkConf
+ badconf.set("spark.rpc", "akka")
badconf.set("spark.authenticate", "true")
badconf.set("spark.authenticate.secret", "bad")
val securityManagerBad = new SecurityManager(badconf)
assert(securityManagerBad.isAuthenticationEnabled() === true)
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = conf, securityManager = securityManagerBad)
+ val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad)
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
intercept[akka.actor.ActorNotFound] {
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
}
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
test("remote fetch security off") {
@@ -81,28 +76,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
val securityManager = new SecurityManager(conf)
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager)
+ System.setProperty("spark.hostPort", rpcEnv.address.hostPort)
assert(securityManager.isAuthenticationEnabled() === false)
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val badconf = new SparkConf
badconf.set("spark.authenticate", "false")
badconf.set("spark.authenticate.secret", "good")
val securityManagerBad = new SecurityManager(badconf)
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = badconf, securityManager = securityManagerBad)
+ val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad)
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
assert(securityManagerBad.isAuthenticationEnabled() === false)
@@ -120,8 +111,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000), size1000)))
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
test("remote fetch security pass") {
@@ -131,15 +122,14 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
val securityManager = new SecurityManager(conf)
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager)
+ System.setProperty("spark.hostPort", rpcEnv.address.hostPort)
assert(securityManager.isAuthenticationEnabled() === true)
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val goodconf = new SparkConf
goodconf.set("spark.authenticate", "true")
@@ -148,13 +138,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
assert(securityManagerGood.isAuthenticationEnabled() === true)
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = goodconf, securityManager = securityManagerGood)
+ val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood)
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
masterTracker.registerShuffle(10, 1)
masterTracker.incrementEpoch()
@@ -170,47 +157,45 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000), size1000)))
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
test("remote fetch security off client") {
val conf = new SparkConf
+ conf.set("spark.rpc", "akka")
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
val securityManager = new SecurityManager(conf)
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager)
+ System.setProperty("spark.hostPort", rpcEnv.address.hostPort)
assert(securityManager.isAuthenticationEnabled() === true)
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val badconf = new SparkConf
+ badconf.set("spark.rpc", "akka")
badconf.set("spark.authenticate", "false")
badconf.set("spark.authenticate.secret", "bad")
val securityManagerBad = new SecurityManager(badconf)
assert(securityManagerBad.isAuthenticationEnabled() === false)
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = badconf, securityManager = securityManagerBad)
+ val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad)
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
intercept[akka.actor.ActorNotFound] {
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
}
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
test("remote fetch ssl on") {
@@ -218,26 +203,22 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
val securityManager = new SecurityManager(conf)
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager)
+ System.setProperty("spark.hostPort", rpcEnv.address.hostPort)
assert(securityManager.isAuthenticationEnabled() === false)
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val slaveConf = sparkSSLConfig()
val securityManagerBad = new SecurityManager(slaveConf)
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = slaveConf, securityManager = securityManagerBad)
+ val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad)
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
assert(securityManagerBad.isAuthenticationEnabled() === false)
@@ -255,8 +236,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000), size1000)))
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
@@ -267,28 +248,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
val securityManager = new SecurityManager(conf)
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager)
+ System.setProperty("spark.hostPort", rpcEnv.address.hostPort)
assert(securityManager.isAuthenticationEnabled() === true)
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val slaveConf = sparkSSLConfig()
slaveConf.set("spark.authenticate", "true")
slaveConf.set("spark.authenticate.secret", "good")
val securityManagerBad = new SecurityManager(slaveConf)
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = slaveConf, securityManager = securityManagerBad)
+ val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad)
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
assert(securityManagerBad.isAuthenticationEnabled() === true)
@@ -305,45 +282,43 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000), size1000)))
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
test("remote fetch ssl on and security enabled - bad credentials") {
val conf = sparkSSLConfig()
+ conf.set("spark.rpc", "akka")
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
val securityManager = new SecurityManager(conf)
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager)
+ System.setProperty("spark.hostPort", rpcEnv.address.hostPort)
assert(securityManager.isAuthenticationEnabled() === true)
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val slaveConf = sparkSSLConfig()
+ slaveConf.set("spark.rpc", "akka")
slaveConf.set("spark.authenticate", "true")
slaveConf.set("spark.authenticate.secret", "bad")
val securityManagerBad = new SecurityManager(slaveConf)
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = slaveConf, securityManager = securityManagerBad)
+ val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad)
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
intercept[akka.actor.ActorNotFound] {
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
}
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
@@ -352,35 +327,30 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro
val securityManager = new SecurityManager(conf)
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager)
+ System.setProperty("spark.hostPort", rpcEnv.address.hostPort)
assert(securityManager.isAuthenticationEnabled() === false)
val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val slaveConf = sparkSSLConfig()
val securityManagerBad = new SecurityManager(slaveConf)
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = slaveConf, securityManager = securityManagerBad)
+ val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad)
val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker"))
- val timeout = AkkaUtils.lookupTimeout(conf)
- val result = Try(Await.result(selection.resolveOne(timeout * 2), timeout))
-
- result match {
- case Failure(ex: ActorNotFound) =>
- case Failure(ex: TimeoutException) =>
- case r => fail(s"$r is neither Failure(ActorNotFound) nor Failure(TimeoutException)")
+ try {
+ slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+ fail("should receive either ActorNotFound or TimeoutException")
+ } catch {
+ case e: ActorNotFound =>
+ case e: TimeoutException =>
}
- actorSystem.shutdown()
- slaveSystem.shutdown()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
}
}