From 4f5a24d7e73104771f233af041eeba4f41675974 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Thu, 31 Dec 2015 00:15:55 -0800
Subject: [SPARK-7995][SPARK-6280][CORE] Remove AkkaRpcEnv and remove systemName from setupEndpointRef

### Remove AkkaRpcEnv

Keep `SparkEnv.actorSystem` because Streaming still uses it. I will remove it and `AkkaUtils` after refactoring Streaming's `actorStream` API.

### Remove systemName

There are two places that use `systemName`:

* `RpcEnvConfig.name`. Although `AkkaRpcEnv` used it as the `systemName`, `NettyRpcEnv` uses it as the service name in the log line `Successfully started service *** on port ***`. Since that service name is useful in the logs, I keep `RpcEnvConfig.name`.
* `def setupEndpointRef(systemName: String, address: RpcAddress, endpointName: String)`. Each `ActorSystem` has a `systemName`; Akka requires the `systemName` in its URIs and refuses a connection if it does not match. `NettyRpcEnv`, however, does not use it, so `systemName` can be dropped from `setupEndpointRef` now that `AkkaRpcEnv` is being removed. (A usage sketch of the new API follows the patch.)

### Remove RpcEnv.uriOf

`uriOf` exists because Akka uses different URI formats with and without authentication, e.g. `akka.ssl.tcp://...` and `akka.tcp://...`, while `NettyRpcEnv` always uses the same format. It is therefore no longer needed once `AkkaRpcEnv` is removed.

Author: Shixiong Zhu

Closes #10459 from zsxwing/remove-akka-rpc-env.
---
 .../main/scala/org/apache/spark/SparkConf.scala    |   3 +-
 .../src/main/scala/org/apache/spark/SparkEnv.scala |  12 +-
 .../scala/org/apache/spark/deploy/Client.scala     |   2 +-
 .../org/apache/spark/deploy/client/AppClient.scala |   3 +-
 .../org/apache/spark/deploy/worker/Worker.scala    |  11 +-
 .../org/apache/spark/rpc/RpcEndpointAddress.scala  |  73 ++++
 .../main/scala/org/apache/spark/rpc/RpcEnv.scala   |  28 +-
 .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala     | 404 ---------------------
 .../org/apache/spark/rpc/netty/NettyRpcEnv.scala   |   5 +-
 .../spark/rpc/netty/RpcEndpointAddress.scala       |  70 ----
 .../scheduler/cluster/SimrSchedulerBackend.scala   |  11 +-
 .../cluster/SparkDeploySchedulerBackend.scala      |   9 +-
 .../mesos/CoarseMesosSchedulerBackend.scala        |  10 +-
 .../org/apache/spark/util/ActorLogReceive.scala    |  70 ----
 .../scala/org/apache/spark/util/AkkaUtils.scala    | 107 +-----
 .../scala/org/apache/spark/util/RpcUtils.scala     |  11 +-
 .../org/apache/spark/MapOutputTrackerSuite.scala   |   2 +-
 .../scala/org/apache/spark/SSLSampleConfigs.scala  |   2 -
 .../deploy/StandaloneDynamicAllocationSuite.scala  |   2 +-
 .../spark/deploy/client/AppClientSuite.scala       |   2 +-
 .../apache/spark/deploy/master/MasterSuite.scala   |   2 +-
 .../apache/spark/deploy/worker/WorkerSuite.scala   |   8 +-
 .../spark/deploy/worker/WorkerWatcherSuite.scala   |   6 +-
 .../scala/org/apache/spark/rpc/RpcEnvSuite.scala   |  30 +-
 .../apache/spark/rpc/akka/AkkaRpcEnvSuite.scala    |  71 ----
 .../spark/rpc/netty/NettyRpcAddressSuite.scala     |   1 +
 .../apache/spark/rpc/netty/NettyRpcEnvSuite.scala  |   4 +-
 .../org/apache/spark/util/AkkaUtilsSuite.scala     | 360 ------------------
 28 files changed, 140 insertions(+), 1179 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
 delete mode 100644 core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
 delete mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
 delete mode 100644 core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala
 delete mode 100644 core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
 delete mode 100644
core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala (limited to 'core') diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index d3384fb297..ff2c4c34c0 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -544,7 +544,8 @@ private[spark] object SparkConf extends Logging { DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + - "are no longer accepted. To specify the equivalent now, one may use '64k'.") + "are no longer accepted. To specify the equivalent now, one may use '64k'."), + DeprecatedConfig("spark.rpc", "2.0", "Not used any more.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 52acde1b41..b98cc964ed 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,7 +34,6 @@ import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemor import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer @@ -97,9 +96,7 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { - actorSystem.shutdown() - } + actorSystem.shutdown() rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -248,14 +245,11 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) - // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + // Create the ActorSystem for Akka and get the port it binds to. val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager, clientMode = !isDriver) - val actorSystem: ActorSystem = - if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { - rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - } else { + val actorSystem: ActorSystem = { val actorSystemPort = if (port == 0 || rpcEnv.address == null) { port diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index f03875a3e8..328a1bb84f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -230,7 +230,7 @@ object Client { RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). 
- map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + map(rpcEnv.setupEndpointRef(_, Master.ENDPOINT_NAME)) rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) rpcEnv.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index a5753e1053..f7c33214c2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -104,8 +104,7 @@ private[spark] class AppClient( return } logInfo("Connecting to master " + masterAddress.toSparkURL + "...") - val masterRef = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterRef = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) masterRef.send(RegisterApplication(appDescription, self)) } catch { case ie: InterruptedException => // Cancelled diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 84e7b366bc..37b94e02cc 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -45,7 +45,6 @@ private[deploy] class Worker( cores: Int, memory: Int, masterRpcAddresses: Array[RpcAddress], - systemName: String, endpointName: String, workDirPath: String = null, val conf: SparkConf, @@ -101,7 +100,7 @@ private[deploy] class Worker( private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private val workerUri = RpcEndpointAddress(rpcEnv.address, endpointName).toString private var registered = false private var connected = false private val workerId = generateWorkerId() @@ -209,8 +208,7 @@ private[deploy] class Worker( override def run(): Unit = { try { logInfo("Connecting to master " + masterAddress + "...") - val masterEndpoint = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled @@ -266,8 +264,7 @@ private[deploy] class Worker( override def run(): Unit = { try { logInfo("Connecting to master " + masterAddress + "...") - val masterEndpoint = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled @@ -711,7 +708,7 @@ private[deploy] object Worker extends Logging { val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory, - masterAddresses, systemName, ENDPOINT_NAME, workDir, conf, securityMgr)) + masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr)) rpcEnv } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala new file mode 100644 index 0000000000..b9db60a779 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala 
@@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +/** + * An address identifier for an RPC endpoint. + * + * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only + * connection and can only be reached via the client that sent the endpoint reference. + * + * @param rpcAddress The socket address of the endpoint. + * @param name Name of the endpoint. + */ +private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { + + require(name != null, "RpcEndpoint name must be provided.") + + def this(host: String, port: Int, name: String) = { + this(RpcAddress(host, port), name) + } + + override val toString = if (rpcAddress != null) { + s"spark://$name@${rpcAddress.host}:${rpcAddress.port}" + } else { + s"spark-client://$name" + } +} + +private[spark] object RpcEndpointAddress { + + def apply(host: String, port: Int, name: String): RpcEndpointAddress = { + new RpcEndpointAddress(host, port, name) + } + + def apply(sparkUrl: String): RpcEndpointAddress = { + try { + val uri = new java.net.URI(sparkUrl) + val host = uri.getHost + val port = uri.getPort + val name = uri.getUserInfo + if (uri.getScheme != "spark" || + host == null || + port < 0 || + name == null || + (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null + uri.getFragment != null || + uri.getQuery != null) { + throw new SparkException("Invalid Spark URL: " + sparkUrl) + } + new RpcEndpointAddress(host, port, name) + } catch { + case e: java.net.URISyntaxException => + throw new SparkException("Invalid Spark URL: " + sparkUrl, e) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 64a4a8bf7c..5668377133 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -23,7 +23,8 @@ import java.nio.channels.ReadableByteChannel import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.rpc.netty.NettyRpcEnvFactory +import org.apache.spark.util.RpcUtils /** @@ -32,15 +33,6 @@ import org.apache.spark.util.{RpcUtils, Utils} */ private[spark] object RpcEnv { - private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - val rpcEnvNames = Map( - "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", - "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "netty") - val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - 
Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] - } - def create( name: String, host: String, @@ -48,9 +40,8 @@ private[spark] object RpcEnv { conf: SparkConf, securityManager: SecurityManager, clientMode: Boolean = false): RpcEnv = { - // Using Reflection to create the RpcEnv to avoid to depend on Akka directly val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode) - getRpcEnvFactory(conf).create(config) + new NettyRpcEnvFactory().create(config) } } @@ -98,12 +89,11 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { } /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName`. * This is a blocking action. */ - def setupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { - setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(RpcEndpointAddress(address, endpointName).toString) } /** @@ -124,12 +114,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def awaitTermination(): Unit - /** - * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of - * creating it manually because different [[RpcEnv]] may have different formats. - */ - def uriOf(systemName: String, address: RpcAddress, endpointName: String): String - /** * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala deleted file mode 100644 index 9d098154f7..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.rpc.akka - -import java.io.File -import java.nio.channels.ReadableByteChannel -import java.util.concurrent.ConcurrentHashMap - -import scala.concurrent.Future -import scala.language.postfixOps -import scala.reflect.ClassTag -import scala.util.control.NonFatal - -import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} -import akka.event.Logging.Error -import akka.pattern.{ask => akkaAsk} -import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import akka.serialization.JavaSerializer - -import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} -import org.apache.spark.rpc._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} - -/** - * A RpcEnv implementation based on Akka. - * - * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and - * remove Akka from the dependencies. - */ -private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, - val securityManager: SecurityManager, - conf: SparkConf, - boundPort: Int) - extends RpcEnv(conf) with Logging { - - private val defaultAddress: RpcAddress = { - val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress - // In some test case, ActorSystem doesn't bind to any address. - // So just use some default value since they are only some unit tests - RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) - } - - override val address: RpcAddress = defaultAddress - - /** - * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make - * [[RpcEndpoint.self]] work. - */ - private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() - - /** - * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` - */ - private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() - - private val _fileServer = new AkkaFileServer(conf, securityManager) - - private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { - endpointToRef.put(endpoint, endpointRef) - refToEndpoint.put(endpointRef, endpoint) - } - - private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { - val endpoint = refToEndpoint.remove(endpointRef) - if (endpoint != null) { - endpointToRef.remove(endpoint) - } - } - - /** - * Retrieve the [[RpcEndpointRef]] of `endpoint`. - */ - override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) - - override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use defered function because the Actor needs to use `endpointRef`. - // So `actorRef` should be created after assigning `endpointRef`. 
- val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { - - assert(endpointRef != null) - - override def preStart(): Unit = { - // Listen for remote client network events - context.system.eventStream.subscribe(self, classOf[AssociationEvent]) - safelyCall(endpoint) { - endpoint.onStart() - } - } - - override def receiveWithLogging: Receive = { - case AssociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) - } - - case DisassociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) - } - - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => - safelyCall(endpoint) { - endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) - } - - case e: AssociationEvent => - // TODO ignore? - - case m: AkkaMessage => - logDebug(s"Received RPC message: $m") - safelyCall(endpoint) { - processMessage(endpoint, m, sender) - } - - case AkkaFailure(e) => - safelyCall(endpoint) { - throw e - } - - case message: Any => { - logWarning(s"Unknown message: $message") - } - - } - - override def postStop(): Unit = { - unregisterEndpoint(endpoint.self) - safelyCall(endpoint) { - endpoint.onStop() - } - } - - }), name = name) - endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) - registerEndpoint(endpoint, endpointRef) - // Now actorRef can be created safely - endpointRef.init() - endpointRef - } - - private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { - val message = m.message - val needReply = m.needReply - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(new RpcCallContext { - override def sendFailure(e: Throwable): Unit = { - _sender ! AkkaFailure(e) - } - - override def reply(response: Any): Unit = { - _sender ! AkkaMessage(response, false) - } - - // Use "lazy" because most of RpcEndpoints don't need "senderAddress" - override lazy val senderAddress: RpcAddress = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address - }) - } else { - endpoint.receive - } - try { - pf.applyOrElse[Any, Unit](message, { message => - throw new SparkException(s"Unmatched message $message from ${_sender}") - }) - } catch { - case NonFatal(e) => - _sender ! AkkaFailure(e) - if (!needReply) { - // If the sender does not require a reply, it may not handle the exception. So we rethrow - // "e" to make sure it will be processed. - throw e - } - } - } - - /** - * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will - * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. - */ - private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) - } - } - } - } - - private def akkaAddressToRpcAddress(address: Address): RpcAddress = { - RpcAddress(address.host.getOrElse(defaultAddress.host), - address.port.getOrElse(defaultAddress.port)) - } - - override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). 
- // this is just in case there is a timeout from creating the future in resolveOne, we want the - // exception to indicate the conf that determines the timeout - recover(defaultLookupTimeout.addMessageIfTimeout) - } - - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { - AkkaUtils.address( - AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) - } - - override def shutdown(): Unit = { - actorSystem.shutdown() - _fileServer.shutdown() - } - - override def stop(endpoint: RpcEndpointRef): Unit = { - require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) - actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) - } - - override def awaitTermination(): Unit = { - actorSystem.awaitTermination() - } - - override def toString: String = s"${getClass.getSimpleName}($actorSystem)" - - override def deserialize[T](deserializationAction: () => T): T = { - JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { - deserializationAction() - } - } - - override def openChannel(uri: String): ReadableByteChannel = { - throw new UnsupportedOperationException( - "AkkaRpcEnv's files should be retrieved using an HTTP client.") - } - - override def fileServer: RpcEnvFileServer = _fileServer - -} - -private[akka] class AkkaFileServer( - conf: SparkConf, - securityManager: SecurityManager) extends RpcEnvFileServer { - - @volatile private var httpFileServer: HttpFileServer = _ - - override def addFile(file: File): String = { - getFileServer().addFile(file) - } - - override def addJar(file: File): String = { - getFileServer().addJar(file) - } - - override def addDirectory(baseUri: String, path: File): String = { - val fixedBaseUri = validateDirectoryUri(baseUri) - getFileServer().addDirectory(fixedBaseUri, path.getAbsolutePath()) - } - - def shutdown(): Unit = { - if (httpFileServer != null) { - httpFileServer.stop() - } - } - - private def getFileServer(): HttpFileServer = { - if (httpFileServer == null) synchronized { - if (httpFileServer == null) { - httpFileServer = startFileServer() - } - } - httpFileServer - } - - private def startFileServer(): HttpFileServer = { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - server - } - -} - -private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { - - def create(config: RpcEnvConfig): RpcEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - config.name, config.host, config.port, config.conf, config.securityManager) - actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) - } -} - -/** - * Monitor errors reported by Akka and log them. 
- */ -private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[Error]) - } - - override def receiveWithLogging: Actor.Receive = { - case Error(cause: Throwable, _, _, message: String) => logDebug(message, cause) - } -} - -private[akka] class AkkaRpcEndpointRef( - @transient private val defaultAddress: RpcAddress, - @transient private val _actorRef: () => ActorRef, - conf: SparkConf, - initInConstructor: Boolean) - extends RpcEndpointRef(conf) with Logging { - - def this( - defaultAddress: RpcAddress, - _actorRef: ActorRef, - conf: SparkConf) = { - this(defaultAddress, () => _actorRef, conf, true) - } - - lazy val actorRef = _actorRef() - - override lazy val address: RpcAddress = { - val akkaAddress = actorRef.path.address - RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), - akkaAddress.port.getOrElse(defaultAddress.port)) - } - - override lazy val name: String = actorRef.path.name - - private[akka] def init(): Unit = { - // Initialize the lazy vals - actorRef - address - name - } - - if (initInConstructor) { - init() - } - - override def send(message: Any): Unit = { - actorRef ! AkkaMessage(message, false) - } - - override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { - // The function will run in the calling thread, so it should be short and never block. - case msg @ AkkaMessage(message, reply) => - if (reply) { - logError(s"Receive $msg but the sender cannot reply") - Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) - } else { - Future.successful(message) - } - case AkkaFailure(e) => - Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T]. - recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) - } - - override def toString: String = s"${getClass.getSimpleName}($actorRef)" - - final override def equals(that: Any): Boolean = that match { - case other: AkkaRpcEndpointRef => actorRef == other.actorRef - case _ => false - } - - final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() -} - -/** - * A wrapper to `message` so that the receiver knows if the sender expects a reply. 
- * @param message - * @param needReply if the sender expects a reply message - */ -private[akka] case class AkkaMessage(message: Any, needReply: Boolean) - -/** - * A reply with the failure error from the receiver to the sender - */ -private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 090a1b9f6e..ef876b1d8c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -257,9 +257,6 @@ private[netty] class NettyRpcEnv( dispatcher.getRpcEndpointRef(endpoint) } - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new RpcEndpointAddress(address, endpointName).toString - override def shutdown(): Unit = { cleanup() } @@ -427,7 +424,7 @@ private[netty] object NettyRpcEnv extends Logging { } -private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { +private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { def create(config: RpcEnvConfig): RpcEnv = { val sparkConf = config.conf diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala deleted file mode 100644 index cd6f00cc08..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.netty - -import org.apache.spark.SparkException -import org.apache.spark.rpc.RpcAddress - -/** - * An address identifier for an RPC endpoint. - * - * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only - * connection and can only be reached via the client that sent the endpoint reference. - * - * @param rpcAddress The socket address of the endpoint. - * @param name Name of the endpoint. 
- */ -private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { - - require(name != null, "RpcEndpoint name must be provided.") - - def this(host: String, port: Int, name: String) = { - this(RpcAddress(host, port), name) - } - - override val toString = if (rpcAddress != null) { - s"spark://$name@${rpcAddress.host}:${rpcAddress.port}" - } else { - s"spark-client://$name" - } -} - -private[netty] object RpcEndpointAddress { - - def apply(sparkUrl: String): RpcEndpointAddress = { - try { - val uri = new java.net.URI(sparkUrl) - val host = uri.getHost - val port = uri.getPort - val name = uri.getUserInfo - if (uri.getScheme != "spark" || - host == null || - port < 0 || - name == null || - (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null - uri.getFragment != null || - uri.getQuery != null) { - throw new SparkException("Invalid Spark URL: " + sparkUrl) - } - new RpcEndpointAddress(host, port, name) - } catch { - case e: java.net.URISyntaxException => - throw new SparkException("Invalid Spark URL: " + sparkUrl, e) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 641638a77d..781ecfff7e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} -import org.apache.spark.rpc.RpcAddress -import org.apache.spark.{Logging, SparkContext, SparkEnv} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class SimrSchedulerBackend( @@ -39,9 +39,10 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, - RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + sc.conf.get("spark.driver.host"), + sc.conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 5105475c76..1209cce6d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress} import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} @@ -54,9 +54,10 @@ private[spark] class SparkDeploySchedulerBackend( launcherBackend.connect() // The endpoint for executors to talk to us - val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, - 
RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + sc.conf.get("spark.driver.host"), + sc.conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 7d08eae0b4..a4ed85cd2a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -31,7 +31,7 @@ import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -215,10 +215,10 @@ private[spark] class CoarseMesosSchedulerBackend( if (conf.contains("spark.testing")) { "driverURL" } else { - sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString } } diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala deleted file mode 100644 index 81a7cbde01..0000000000 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import akka.actor.Actor -import org.slf4j.Logger - -/** - * A trait to enable logging all Akka actor messages. Here's an example of using this: - * - * {{{ - * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging { - * ... - * override def receiveWithLogging = { - * case GetLocations(blockId) => - * sender ! getLocations(blockId) - * ... - * } - * ... 
- * } - * }}} - * - */ -private[spark] trait ActorLogReceive { - self: Actor => - - override def receive: Actor.Receive = new Actor.Receive { - - private val _receiveWithLogging = receiveWithLogging - - override def isDefinedAt(o: Any): Boolean = { - val handled = _receiveWithLogging.isDefinedAt(o) - if (!handled) { - log.debug(s"Received unexpected actor system event: $o") - } - handled - } - - override def apply(o: Any): Unit = { - if (log.isDebugEnabled) { - log.debug(s"[actor] received message $o from ${self.sender}") - } - val start = System.nanoTime - _receiveWithLogging.apply(o) - val timeTaken = (System.nanoTime - start).toDouble / 1000000 - if (log.isDebugEnabled) { - log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}") - } - } - } - - def receiveWithLogging: Actor.Receive - - protected def log: Logger -} diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 1738258a0c..f2d93edd4f 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -19,14 +19,11 @@ package org.apache.spark.util import scala.collection.JavaConverters._ -import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} -import akka.pattern.ask - +import akka.actor.{ActorSystem, ExtendedActorSystem} import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} -import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. @@ -139,104 +136,4 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. */ val reservedSizeBytes = 200 * 1024 - /** - * Send a message to the given actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - def askWithReply[T]( - message: Any, - actor: ActorRef, - timeout: RpcTimeout): T = { - askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout) - } - - /** - * Send a message to the given actor and get its result within a default timeout, or - * throw a SparkException if this fails even after the specified number of retries. 
- */ - def askWithReply[T]( - message: Any, - actor: ActorRef, - maxAttempts: Int, - retryInterval: Long, - timeout: RpcTimeout): T = { - // TODO: Consider removing multiple attempts - if (actor == null) { - throw new SparkException(s"Error sending message [message = $message]" + - " as actor is null ") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < maxAttempts) { - attempts += 1 - try { - val future = actor.ask(message)(timeout.duration) - val result = timeout.awaitResult(future) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - if (attempts < maxAttempts) { - Thread.sleep(retryInterval) - } - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - - def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { - val driverActorSystemName = SparkEnv.driverActorSystemName - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") - val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = RpcUtils.lookupRpcTimeout(conf) - logInfo(s"Connecting to $name: $url") - timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) - } - - def makeExecutorRef( - name: String, - conf: SparkConf, - host: String, - port: Int, - actorSystem: ActorSystem): ActorRef = { - val executorActorSystemName = SparkEnv.executorActorSystemName - Utils.checkHost(host, "Expected hostname") - val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = RpcUtils.lookupRpcTimeout(conf) - logInfo(s"Connecting to $name: $url") - timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) - } - - def protocol(actorSystem: ActorSystem): String = { - val akkaConf = actorSystem.settings.config - val sslProp = "akka.remote.netty.tcp.enable-ssl" - protocol(akkaConf.hasPath(sslProp) && akkaConf.getBoolean(sslProp)) - } - - def protocol(ssl: Boolean = false): String = { - if (ssl) { - "akka.ssl.tcp" - } else { - "akka.tcp" - } - } - - def address( - protocol: String, - systemName: String, - host: String, - port: Int, - actorName: String): String = { - s"$protocol://$systemName@$host:$port/user/$actorName" - } - } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 7578a3b1d8..a51f30b9c2 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -20,20 +20,19 @@ package org.apache.spark.util import scala.concurrent.duration.FiniteDuration import scala.language.postfixOps -import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} -object RpcUtils { +private[spark] object RpcUtils { /** * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. 
*/ def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { - val driverActorSystemName = SparkEnv.driverActorSystemName val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") - rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } /** Returns the configured number of times to retry connecting */ @@ -47,7 +46,7 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC ask operations. */ - private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = { + def askRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") } @@ -57,7 +56,7 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ - private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { + def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7e70308bb3..5b29d69cd9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -125,7 +125,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 2d14249855..33270bec62 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -41,7 +41,6 @@ object SSLSampleConfigs { def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -55,7 +54,6 @@ object SSLSampleConfigs { def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) conf.set("spark.ssl.keyStorePassword", "password") diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 85c1c1bbf3..ab3d4cafeb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -474,7 +474,7 @@ class StandaloneDynamicAllocationSuite (0 until numWorkers).map { i => val rpcEnv = workerRpcEnvs(i) val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), - Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, 
conf, securityManager) + Worker.ENDPOINT_NAME, null, conf, securityManager) rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) worker } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 415e2b37db..eb794b6739 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -147,7 +147,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd (0 until numWorkers).map { i => val rpcEnv = workerRpcEnvs(i) val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), - Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + Worker.ENDPOINT_NAME, null, conf, securityManager) rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) worker } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 242bf4b556..10e33a32ba 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -98,7 +98,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf) try { - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) + rpcEnv.setupEndpointRef(rpcEnv.address, Master.ENDPOINT_NAME) CustomPersistenceEngine.lastInstance.isDefined shouldBe true val persistenceEngine = CustomPersistenceEngine.lastInstance.get diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index faed4bdc68..082d5e86eb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -67,7 +67,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedExecutors", 2.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 5) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -93,7 +93,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedExecutors", 30.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 50) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -128,7 +128,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedDrivers", 2.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new 
SecurityManager(conf)) // initialize workers for (i <- 0 until 5) { val driverId = s"driverId-$i" @@ -154,7 +154,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedDrivers", 30.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 50) { val driverId = s"driverId-$i" diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 40c24bdecc..0ffd91d8ff 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.deploy.worker import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager -import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress, RpcEnv} class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val targetWorkerUrl = RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "Worker").toString val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) @@ -36,7 +36,7 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val targetWorkerUrl = RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "Worker").toString val otherRpcAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 9c850c0da5..924fce7f61 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -94,7 +94,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "send-remotely") try { rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { @@ -148,7 +148,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") + val 
rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") try { val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) @@ -176,7 +176,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.rpc.numRetries", "1") val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-timeout") try { // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause val e = intercept[SparkException] { @@ -435,7 +435,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") try { val f = rpcEndpointRef.ask[String]("hello") val ack = Await.result(f, 5 seconds) @@ -475,8 +475,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "sendWithReply-remotely-error") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely-error") try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { @@ -527,8 +526,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") try { - val serverRefInServer2 = - serverEnv1.setupEndpointRef("server2", serverRef2.address, serverRef2.name) + val serverRefInServer2 = serverEnv1.setupEndpointRef(serverRef2.address, serverRef2.name) // Send a message to set up the connection serverRefInServer2.send("hello") @@ -556,8 +554,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) try { - val serverRefInClient = - clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) // Send a message to set up the connection serverRefInClient.send("hello") @@ -588,8 +585,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") try { - val serverRefInClient = - clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) // Send a message to set up the connection serverRefInClient.send("hello") @@ -623,8 +619,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", 
env.address, "sendWithReply-unserializable-error") + val rpcEndpointRef = + anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[Exception] { @@ -661,8 +657,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { case msg: String => message = msg } }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication") + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "send-authentication") rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { assert("hello" === message) @@ -693,8 +688,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication") + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } finally { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala deleted file mode 100644 index 7aac02775e..0000000000 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.rpc.akka - -import org.apache.spark.rpc._ -import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} - -class AkkaRpcEnvSuite extends RpcEnvSuite { - - override def createRpcEnv(conf: SparkConf, - name: String, - port: Int, - clientMode: Boolean = false): RpcEnv = { - new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode)) - } - - test("setupEndpointRef: systemName, address, endpointName") { - val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { - override val rpcEnv = env - - override def receive = { - case _ => - } - }) - val conf = new SparkConf() - val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false)) - try { - val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") - assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === - newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) - } finally { - newRpcEnv.shutdown() - } - } - - test("uriOf") { - val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") - assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) - } - - test("uriOf: ssl") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val securityManager = new SecurityManager(conf) - val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 0, securityManager, false)) - try { - val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") - assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) - } finally { - rpcEnv.shutdown() - } - } - -} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index 56743ba650..4fcdb619f9 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import org.apache.spark.SparkFunSuite +import org.apache.spark.rpc.RpcEndpointAddress class NettyRpcAddressSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index ce83087ec0..994a58836b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -33,9 +33,9 @@ class NettyRpcEnvSuite extends RpcEnvSuite { } test("non-existent endpoint") { - val uri = env.uriOf("test", env.address, "nonexist-endpoint") + val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString val e = intercept[RpcEndpointNotFoundException] { - env.setupEndpointRef("test", env.address, "nonexist-endpoint") + env.setupEndpointRef(env.address, "nonexist-endpoint") } assert(e.getMessage.contains(uri)) } diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala deleted file mode 100644 index 0af4b6098b..0000000000 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ /dev/null @@ -1,360 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable.ArrayBuffer - -import java.util.concurrent.TimeoutException - -import akka.actor.ActorNotFound - -import org.apache.spark._ -import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} -import org.apache.spark.SSLSampleConfigs._ - - -/** - * Test the AkkaUtils with various security settings. - */ -class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { - - test("remote fetch security bad password") { - val conf = new SparkConf - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val securityManager = new SecurityManager(conf) - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.rpc", "akka") - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - conf.set("spark.authenticate.secret", "bad") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "false") - badconf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(badconf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, 
MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security off - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security pass") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val goodconf = new SparkConf - goodconf.set("spark.authenticate", "true") - goodconf.set("spark.authenticate.secret", "good") - val securityManagerGood = new SecurityManager(goodconf) - - assert(securityManagerGood.isAuthenticationEnabled() === true) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security on and passwords match - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security off client") { - val conf = new SparkConf - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.rpc", "akka") - badconf.set("spark.authenticate", "false") - badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf) - - assert(securityManagerBad.isAuthenticationEnabled() === false) 
- - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch ssl on") { - val conf = sparkSSLConfig() - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security off - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on and security enabled") { - val conf = sparkSSLConfig() - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - slaveConf.set("spark.authenticate", "true") - slaveConf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - 
slaveTracker.updateEpoch(masterTracker.getEpoch) - - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on and security enabled - bad credentials") { - val conf = sparkSSLConfig() - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - slaveConf.set("spark.rpc", "akka") - slaveConf.set("spark.authenticate", "true") - slaveConf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on - untrusted server") { - val conf = sparkSSLConfigUntrusted() - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - .set("spark.rpc.askTimeout", "5s") - .set("spark.rpc.lookupTimeout", "5s") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - try { - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - fail("should receive either ActorNotFound or TimeoutException") - } catch { - case e: ActorNotFound => - case e: TimeoutException => - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - -} -- cgit v1.2.3
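
A minimal sketch (not part of the patch) of how an endpoint lookup reads once AkkaRpcEnv and the systemName parameter are gone, built only from APIs exercised by the updated tests above. The endpoint name "greeter" and the service names "sketch-server"/"sketch-client" are illustrative assumptions, not names taken from the patch:

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointAddress, RpcEnv}

    object SetupEndpointRefSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()

        // Server-side RpcEnv hosting a trivial endpoint; "greeter" is an illustrative name.
        val serverEnv = RpcEnv.create("sketch-server", "localhost", 0, conf, new SecurityManager(conf))
        serverEnv.setupEndpoint("greeter", new RpcEndpoint {
          override val rpcEnv = serverEnv
          override def receive = {
            case _ => // ignore messages; a real endpoint would handle or reply to them
          }
        })

        // Client-side RpcEnv. After this patch, setupEndpointRef takes only the remote
        // RpcAddress and the endpoint name; there is no Akka systemName argument.
        val clientEnv =
          RpcEnv.create("sketch-client", "localhost", 0, conf, new SecurityManager(conf), clientMode = true)
        val greeterRef = clientEnv.setupEndpointRef(serverEnv.address, "greeter")
        greeterRef.send("hello")

        // uriOf is removed; the string form of an endpoint address is RpcEndpointAddress.toString,
        // as used by WorkerWatcherSuite and NettyRpcEnvSuite above.
        println(RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "greeter").toString)

        clientEnv.shutdown()
        serverEnv.shutdown()
      }
    }

Under these assumptions the lookup is keyed purely by the remote RpcAddress and the endpoint name, mirroring how each setupEndpointRef call site changed in this patch drops its former first argument.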