From 02144d6745ec0a6d8877d969feb82139bd22437f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 24 Sep 2015 08:25:44 -0700 Subject: Revert "[SPARK-6028][Core]A new RPC implemetation based on the network module" This reverts commit 084e4e126211d74a79e8dbd2d0e604dd3c650822. --- .../scala/org/apache/spark/MapOutputTracker.scala | 2 +- .../src/main/scala/org/apache/spark/SparkEnv.scala | 20 +- .../org/apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/deploy/worker/WorkerWatcher.scala | 13 +- .../org/apache/spark/rpc/RpcCallContext.scala | 2 +- .../scala/org/apache/spark/rpc/RpcEndpoint.scala | 51 +-- .../spark/rpc/RpcEndpointNotFoundException.scala | 22 - .../main/scala/org/apache/spark/rpc/RpcEnv.scala | 7 +- .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 6 +- .../org/apache/spark/rpc/netty/Dispatcher.scala | 218 --------- .../org/apache/spark/rpc/netty/IDVerifier.scala | 39 -- .../scala/org/apache/spark/rpc/netty/Inbox.scala | 220 --------- .../apache/spark/rpc/netty/NettyRpcAddress.scala | 56 --- .../spark/rpc/netty/NettyRpcCallContext.scala | 87 ---- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 504 --------------------- .../spark/storage/BlockManagerSlaveEndpoint.scala | 6 +- .../scala/org/apache/spark/util/ThreadUtils.scala | 6 +- .../org/apache/spark/MapOutputTrackerSuite.scala | 10 +- .../scala/org/apache/spark/SSLSampleConfigs.scala | 2 - .../spark/deploy/worker/WorkerWatcherSuite.scala | 6 +- .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 78 +--- .../org/apache/spark/rpc/TestRpcEndpoint.scala | 123 ----- .../org/apache/spark/rpc/netty/InboxSuite.scala | 148 ------ .../spark/rpc/netty/NettyRpcAddressSuite.scala | 29 -- .../apache/spark/rpc/netty/NettyRpcEnvSuite.scala | 38 -- .../spark/rpc/netty/NettyRpcHandlerSuite.scala | 67 --- .../spark/network/client/TransportClient.java | 4 - .../apache/spark/network/server/RpcHandler.java | 2 - .../network/server/TransportRequestHandler.java | 1 - .../streaming/scheduler/ReceiverTracker.scala | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 5 +- 31 files changed, 68 insertions(+), 1708 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala delete mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala delete mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala delete mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala delete mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala delete mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala delete mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala delete mode 100644 core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala delete mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index e380c5b8af..94eb8daa85 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -45,7 +45,7 @@ private[spark] class MapOutputTrackerMasterEndpoint( override def 
receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = context.senderAddress.hostPort + val hostPort = context.sender.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index cfde27fb2e..c6fef7f91f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,10 +20,11 @@ package org.apache.spark import java.io.File import java.net.Socket +import akka.actor.ActorSystem + import scala.collection.mutable import scala.util.Properties -import akka.actor.ActorSystem import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -40,7 +41,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * :: DeveloperApi :: @@ -56,7 +57,6 @@ import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} class SparkEnv ( val executorId: String, private[spark] val rpcEnv: RpcEnv, - _actorSystem: ActorSystem, // TODO Remove actorSystem val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -76,7 +76,7 @@ class SparkEnv ( // TODO Remove actorSystem @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") - val actorSystem: ActorSystem = _actorSystem + val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -100,9 +100,6 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { - actorSystem.shutdown() - } rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -252,13 +249,7 @@ object SparkEnv extends Logging { // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) - val actorSystem: ActorSystem = - if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { - rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - } else { - // Create a ActorSystem for legacy codes - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 - } + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem // Figure out which port Akka actually bound to in case the original port is 0 or occupied. 
if (isDriver) { @@ -404,7 +395,6 @@ object SparkEnv extends Logging { val envInstance = new SparkEnv( executorId, rpcEnv, - actorSystem, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 93a1b3f310..770927c80f 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -329,7 +329,7 @@ private[deploy] class Worker( registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ReregisterWithMaster)) + self.send(ReregisterWithMaster) } }, INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index ab56fde938..735c4f0927 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -24,13 +24,14 @@ import org.apache.spark.rpc._ * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher( - override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false) +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) extends RpcEndpoint with Logging { - logInfo(s"Connecting to worker $workerUrl") - if (!isTesting) { - rpcEnv.asyncSetupEndpointRefByURI(workerUrl) + override def onStart() { + logInfo(s"Connecting to worker $workerUrl") + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) + } } // Used to avoid shutting down JVM during tests @@ -39,6 +40,8 @@ private[spark] class WorkerWatcher( // true rather than calling `System.exit`. The user can check `isShutDown` to know if // `exitNonZero` is called. private[deploy] var isShutDown = false + private[deploy] def setTesting(testing: Boolean) = isTesting = testing + private var isTesting = false // Lets filter events only from the worker's rpc system private val expectedAddress = RpcAddress.fromURIString(workerUrl) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index f527ec86ab..3e5b64265e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -37,5 +37,5 @@ private[spark] trait RpcCallContext { /** * The sender of this message. */ - def senderAddress: RpcAddress + def sender: RpcEndpointRef } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index f1ddc6d2cd..dfcbc51cdf 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -28,6 +28,20 @@ private[spark] trait RpcEnvFactory { def create(config: RpcEnvConfig): RpcEnv } +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. 
In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint + + /** * An end point for the RPC that defines what functions to trigger given a message. * @@ -87,39 +101,38 @@ private[spark] trait RpcEndpoint { } /** - * Invoked when `remoteAddress` is connected to the current node. + * Invoked before [[RpcEndpoint]] starts to handle any message. */ - def onConnected(remoteAddress: RpcAddress): Unit = { + def onStart(): Unit = { // By default, do nothing. } /** - * Invoked when `remoteAddress` is lost. + * Invoked when [[RpcEndpoint]] is stopping. */ - def onDisconnected(remoteAddress: RpcAddress): Unit = { + def onStop(): Unit = { // By default, do nothing. } /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. + * Invoked when `remoteAddress` is connected to the current node. */ - def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + def onConnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked before [[RpcEndpoint]] starts to handle any message. + * Invoked when `remoteAddress` is lost. */ - def onStart(): Unit = { + def onDisconnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot - * use it to send or ask messages. + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. */ - def onStop(): Unit = { + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { // By default, do nothing. } @@ -133,17 +146,3 @@ private[spark] trait RpcEndpoint { } } } - -/** - * A trait that requires RpcEnv thread-safely sending messages to it. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a - * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the - * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same - * [[ThreadSafeRpcEndpoint]] for different messages. - */ -private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint { -} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala deleted file mode 100644 index d177881fb3..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.rpc - -import org.apache.spark.SparkException - -private[rpc] class RpcEndpointNotFoundException(uri: String) - extends SparkException(s"Cannot find endpoint: $uri") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index afe1ee8a0b..29debe8081 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -36,11 +36,8 @@ private[spark] object RpcEnv { private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", - "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") - // Use "netty" by default so that Jenkins can run all tests using NettyRpcEnv. - // Will change it back to "akka" before merging the new implementation. - val rpcEnvName = conf.get("spark.rpc", "netty") + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 95132a4e4a..ad67e1c5ad 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -166,9 +166,9 @@ private[spark] class AkkaRpcEnv private[akka] ( _sender ! AkkaMessage(response, false) } - // Use "lazy" because most of RpcEndpoints don't need "senderAddress" - override lazy val senderAddress: RpcAddress = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf) }) } else { endpoint.receive diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala deleted file mode 100644 index d71e6f01db..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.netty - -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} -import javax.annotation.concurrent.GuardedBy - -import scala.collection.JavaConverters._ -import scala.concurrent.Promise -import scala.util.control.NonFatal - -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.network.client.RpcResponseCallback -import org.apache.spark.rpc._ -import org.apache.spark.util.ThreadUtils - -private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { - - private class EndpointData( - val name: String, - val endpoint: RpcEndpoint, - val ref: NettyRpcEndpointRef) { - val inbox = new Inbox(ref, endpoint) - } - - private val endpoints = new ConcurrentHashMap[String, EndpointData]() - private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() - - // Track the receivers whose inboxes may contain messages. - private val receivers = new LinkedBlockingQueue[EndpointData]() - - @GuardedBy("this") - private var stopped = false - - def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { - val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) - val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) - synchronized { - if (stopped) { - throw new IllegalStateException("RpcEnv has been stopped") - } - if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) { - throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") - } - val data = endpoints.get(name) - endpointRefs.put(data.endpoint, data.ref) - receivers.put(data) - } - endpointRef - } - - def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint) - - def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint) - - // Should be idempotent - private def unregisterRpcEndpoint(name: String): Unit = { - val data = endpoints.remove(name) - if (data != null) { - data.inbox.stop() - receivers.put(data) - } - // Don't clean `endpointRefs` here because it's possible that some messages are being processed - // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via - // `removeRpcEndpointRef`. - } - - def stop(rpcEndpointRef: RpcEndpointRef): Unit = { - synchronized { - if (stopped) { - // This endpoint will be stopped by Dispatcher.stop() method. - return - } - unregisterRpcEndpoint(rpcEndpointRef.name) - } - } - - /** - * Send a message to all registered [[RpcEndpoint]]s. 
- * @param message - */ - def broadcastMessage(message: InboxMessage): Unit = { - val iter = endpoints.keySet().iterator() - while (iter.hasNext) { - val name = iter.next - postMessageToInbox(name, (_) => message, - () => { logWarning(s"Drop ${message} because ${name} has been stopped") }) - } - } - - def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new RemoteNettyRpcCallContext( - nettyEnv, sender, callback, message.senderAddress, message.needReply) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - callback.onFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } - - postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) - } - - def postMessage(message: RequestMessage, p: Promise[Any]): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - p.tryFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } - - postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) - } - - private def postMessageToInbox( - endpointName: String, - createMessageFn: NettyRpcEndpointRef => InboxMessage, - onStopped: () => Unit): Unit = { - val shouldCallOnStop = - synchronized { - val data = endpoints.get(endpointName) - if (stopped || data == null) { - true - } else { - data.inbox.post(createMessageFn(data.ref)) - receivers.put(data) - false - } - } - if (shouldCallOnStop) { - // We don't need to call `onStop` in the `synchronized` block - onStopped() - } - } - - private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism", - Runtime.getRuntime.availableProcessors()) - - private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop") - - (0 until parallelism) foreach { _ => - executor.execute(new MessageLoop) - } - - def stop(): Unit = { - synchronized { - if (stopped) { - return - } - stopped = true - } - // Stop all endpoints. This will queue all endpoints for processing by the message loops. - endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) - // Enqueue a message that tells the message loops to stop. - receivers.put(PoisonEndpoint) - executor.shutdown() - } - - def awaitTermination(): Unit = { - executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) - } - - /** - * Return if the endpoint exists - */ - def verify(name: String): Boolean = { - endpoints.containsKey(name) - } - - private class MessageLoop extends Runnable { - override def run(): Unit = { - try { - while (true) { - try { - val data = receivers.take() - if (data == PoisonEndpoint) { - // Put PoisonEndpoint back so that other MessageLoops can see it. - receivers.put(PoisonEndpoint) - return - } - data.inbox.process(Dispatcher.this) - } catch { - case NonFatal(e) => logError(e.getMessage, e) - } - } - } catch { - case ie: InterruptedException => // exit - } - } - } - - /** - * A poison endpoint that indicates MessageLoop should exit its loop. 
- */ - private val PoisonEndpoint = new EndpointData(null, null, null) -} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala deleted file mode 100644 index 6061c9b8de..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.rpc.netty - -import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} - -/** - * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists - */ -private[netty] case class ID(name: String) - -/** - * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]] - */ -private[netty] class IDVerifier( - override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint { - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case ID(name) => context.reply(dispatcher.verify(name)) - } -} - -private[netty] object IDVerifier { - val NAME = "id-verifier" -} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala deleted file mode 100644 index 4803548365..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.rpc.netty - -import java.util.LinkedList -import javax.annotation.concurrent.GuardedBy - -import scala.util.control.NonFatal - -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} - -private[netty] sealed trait InboxMessage - -private[netty] case class ContentMessage( - senderAddress: RpcAddress, - content: Any, - needReply: Boolean, - context: NettyRpcCallContext) extends InboxMessage - -private[netty] case object OnStart extends InboxMessage - -private[netty] case object OnStop extends InboxMessage - -/** - * A broadcast message that indicates connecting to a remote node. - */ -private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage - -/** - * A broadcast message that indicates a remote connection is lost. - */ -private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage - -/** - * A broadcast message that indicates a network error - */ -private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress) - extends InboxMessage - -/** - * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. - * @param endpointRef - * @param endpoint - */ -private[netty] class Inbox( - val endpointRef: NettyRpcEndpointRef, - val endpoint: RpcEndpoint) extends Logging { - - @GuardedBy("this") - protected val messages = new LinkedList[InboxMessage]() - - @GuardedBy("this") - private var stopped = false - - @GuardedBy("this") - private var enableConcurrent = false - - @GuardedBy("this") - private var workerCount = 0 - - // OnStart should be the first message to process - synchronized { - messages.add(OnStart) - } - - /** - * Process stored messages. - */ - def process(dispatcher: Dispatcher): Unit = { - var message: InboxMessage = null - synchronized { - if (!enableConcurrent && workerCount != 0) { - return - } - message = messages.poll() - if (message != null) { - workerCount += 1 - } else { - return - } - } - while (true) { - safelyCall(endpoint) { - message match { - case ContentMessage(_sender, content, needReply, context) => - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(context) - } else { - endpoint.receive - } - try { - pf.applyOrElse[Any, Unit](content, { msg => - throw new SparkException(s"Unmatched message $message from ${_sender}") - }) - if (!needReply) { - context.finish() - } - } catch { - case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - context.sendFailure(e) - } else { - context.finish() - throw e - } - } - - case OnStart => { - endpoint.onStart() - if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { - synchronized { - enableConcurrent = true - } - } - } - - case OnStop => - assert(synchronized { workerCount } == 1) - dispatcher.removeRpcEndpointRef(endpoint) - endpoint.onStop() - assert(isEmpty, "OnStop should be the last message") - - case Associated(remoteAddress) => - endpoint.onConnected(remoteAddress) - - case Disassociated(remoteAddress) => - endpoint.onDisconnected(remoteAddress) - - case AssociationError(cause, remoteAddress) => - endpoint.onNetworkError(cause, remoteAddress) - } - } - - synchronized { - // "enableConcurrent" will be set to false after `onStop` is called, so we should check it - // every time. 
- if (!enableConcurrent && workerCount != 1) { - // If we are not the only one worker, exit - workerCount -= 1 - return - } - message = messages.poll() - if (message == null) { - workerCount -= 1 - return - } - } - } - } - - def post(message: InboxMessage): Unit = { - val dropped = - synchronized { - if (stopped) { - // We already put "OnStop" into "messages", so we should drop further messages - true - } else { - messages.add(message) - false - } - } - if (dropped) { - onDrop(message) - } - } - - def stop(): Unit = synchronized { - // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last - // message - if (!stopped) { - // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only - // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources - // safely. - enableConcurrent = false - stopped = true - messages.add(OnStop) - // Note: The concurrent events in messages will be processed one by one. - } - } - - // Visible for testing. - protected def onDrop(message: InboxMessage): Unit = { - logWarning(s"Drop ${message} because $endpointRef is stopped") - } - - def isEmpty: Boolean = synchronized { messages.isEmpty } - - private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logWarning(s"Ignore error", e) - } - } - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala deleted file mode 100644 index 1876b25592..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.rpc.netty - -import java.net.URI - -import org.apache.spark.SparkException -import org.apache.spark.rpc.RpcAddress - -private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) { - - def toRpcAddress: RpcAddress = RpcAddress(host, port) - - override val toString = s"spark://$name@$host:$port" -} - -private[netty] object NettyRpcAddress { - - def apply(sparkUrl: String): NettyRpcAddress = { - try { - val uri = new URI(sparkUrl) - val host = uri.getHost - val port = uri.getPort - val name = uri.getUserInfo - if (uri.getScheme != "spark" || - host == null || - port < 0 || - name == null || - (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null - uri.getFragment != null || - uri.getQuery != null) { - throw new SparkException("Invalid Spark URL: " + sparkUrl) - } - NettyRpcAddress(host, port, name) - } catch { - case e: java.net.URISyntaxException => - throw new SparkException("Invalid Spark URL: " + sparkUrl, e) - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala deleted file mode 100644 index 75dcc02a0c..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.netty - -import scala.concurrent.Promise - -import org.apache.spark.Logging -import org.apache.spark.network.client.RpcResponseCallback -import org.apache.spark.rpc.{RpcAddress, RpcCallContext} - -private[netty] abstract class NettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, - override val senderAddress: RpcAddress, - needReply: Boolean) extends RpcCallContext with Logging { - - protected def send(message: Any): Unit - - override def reply(response: Any): Unit = { - if (needReply) { - send(AskResponse(endpointRef, response)) - } else { - throw new IllegalStateException( - s"Cannot send $response to the sender because the sender won't handle it") - } - } - - override def sendFailure(e: Throwable): Unit = { - if (needReply) { - send(AskResponse(endpointRef, RpcFailure(e))) - } else { - logError(e.getMessage, e) - throw new IllegalStateException( - "Cannot send reply to the sender because the sender won't handle it") - } - } - - def finish(): Unit = { - if (!needReply) { - send(Ack(endpointRef)) - } - } -} - -/** - * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. 
- */ -private[netty] class LocalNettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, - senderAddress: RpcAddress, - needReply: Boolean, - p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { - - override protected def send(message: Any): Unit = { - p.success(message) - } -} - -/** - * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back. - */ -private[netty] class RemoteNettyRpcCallContext( - nettyEnv: NettyRpcEnv, - endpointRef: NettyRpcEndpointRef, - callback: RpcResponseCallback, - senderAddress: RpcAddress, - needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { - - override protected def send(message: Any): Unit = { - val reply = nettyEnv.serialize(message) - callback.onSuccess(reply) - } -} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala deleted file mode 100644 index 5522b40782..0000000000 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ /dev/null @@ -1,504 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.rpc.netty - -import java.io._ -import java.net.{InetSocketAddress, URI} -import java.nio.ByteBuffer -import java.util.Arrays -import java.util.concurrent._ -import javax.annotation.concurrent.GuardedBy - -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.concurrent.{Future, Promise} -import scala.reflect.ClassTag -import scala.util.{DynamicVariable, Failure, Success} -import scala.util.control.NonFatal - -import org.apache.spark.{Logging, SecurityManager, SparkConf} -import org.apache.spark.network.TransportContext -import org.apache.spark.network.client._ -import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} -import org.apache.spark.network.server._ -import org.apache.spark.rpc._ -import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} -import org.apache.spark.util.{ThreadUtils, Utils} - -private[netty] class NettyRpcEnv( - val conf: SparkConf, - javaSerializerInstance: JavaSerializerInstance, - host: String, - securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - - private val transportConf = - SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0)) - - private val dispatcher: Dispatcher = new Dispatcher(this) - - private val transportContext = - new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) - - private val clientFactory = { - val bootstraps: Seq[TransportClientBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - Seq(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) - } else { - Nil - } - transportContext.createClientFactory(bootstraps.asJava) - } - - val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") - - // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool - // to implement non-blocking send/ask. 
- // TODO: a non-blocking TransportClientFactory.createClient in future - private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( - "netty-rpc-connection", - conf.getInt("spark.rpc.connect.threads", 256)) - - @volatile private var server: TransportServer = _ - - def start(port: Int): Unit = { - val bootstraps: Seq[TransportServerBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - Seq(new SaslServerBootstrap(transportConf, securityManager)) - } else { - Nil - } - server = transportContext.createServer(port, bootstraps.asJava) - dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) - } - - override lazy val address: RpcAddress = { - require(server != null, "NettyRpcEnv has not yet started") - RpcAddress(host, server.getPort()) - } - - override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - dispatcher.registerRpcEndpoint(name, endpoint) - } - - def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - val addr = NettyRpcAddress(uri) - val endpointRef = new NettyRpcEndpointRef(conf, addr, this) - val idVerifierRef = - new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this) - idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find => - if (find) { - Future.successful(endpointRef) - } else { - Future.failed(new RpcEndpointNotFoundException(uri)) - } - }(ThreadUtils.sameThread) - } - - override def stop(endpointRef: RpcEndpointRef): Unit = { - require(endpointRef.isInstanceOf[NettyRpcEndpointRef]) - dispatcher.stop(endpointRef) - } - - private[netty] def send(message: RequestMessage): Unit = { - val remoteAddr = message.receiver.address - if (remoteAddr == address) { - val promise = Promise[Any]() - dispatcher.postMessage(message, promise) - promise.future.onComplete { - case Success(response) => - val ack = response.asInstanceOf[Ack] - logDebug(s"Receive ack from ${ack.sender}") - case Failure(e) => - logError(s"Exception when sending $message", e) - }(ThreadUtils.sameThread) - } else { - try { - // `createClient` will block if it cannot find a known connection, so we should run it in - // clientConnectionExecutor - clientConnectionExecutor.execute(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) - client.sendRpc(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { - logError(s"Exception when sending $message", e) - } - - override def onSuccess(response: Array[Byte]): Unit = { - val ack = deserialize[Ack](response) - logDebug(s"Receive ack from ${ack.sender}") - } - }) - } - }) - } catch { - case e: RejectedExecutionException => { - // `send` after shutting clientConnectionExecutor down, ignore it - logWarning(s"Cannot send ${message} because RpcEnv is stopped") - } - } - } - } - - private[netty] def ask(message: RequestMessage): Future[Any] = { - val promise = Promise[Any]() - val remoteAddr = message.receiver.address - if (remoteAddr == address) { - val p = Promise[Any]() - dispatcher.postMessage(message, p) - p.future.onComplete { - case Success(response) => - val reply = response.asInstanceOf[AskResponse] - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - case Failure(e) => - if 
(!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } - }(ThreadUtils.sameThread) - } else { - try { - // `createClient` will block if it cannot find a known connection, so we should run it in - // clientConnectionExecutor - clientConnectionExecutor.execute(new Runnable { - override def run(): Unit = { - val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) - client.sendRpc(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } - } - - override def onSuccess(response: Array[Byte]): Unit = { - val reply = deserialize[AskResponse](response) - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - } - }) - } - }) - } catch { - case e: RejectedExecutionException => { - if (!promise.tryFailure(e)) { - logWarning(s"Ignore failure", e) - } - } - } - } - promise.future - } - - private[netty] def serialize(content: Any): Array[Byte] = { - val buffer = javaSerializerInstance.serialize(content) - Arrays.copyOfRange( - buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) - } - - private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = { - deserialize { () => - javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) - } - } - - override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { - dispatcher.getRpcEndpointRef(endpoint) - } - - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new NettyRpcAddress(address.host, address.port, endpointName).toString - - override def shutdown(): Unit = { - cleanup() - } - - override def awaitTermination(): Unit = { - dispatcher.awaitTermination() - } - - private def cleanup(): Unit = { - if (timeoutScheduler != null) { - timeoutScheduler.shutdownNow() - } - if (server != null) { - server.close() - } - if (clientFactory != null) { - clientFactory.close() - } - if (dispatcher != null) { - dispatcher.stop() - } - if (clientConnectionExecutor != null) { - clientConnectionExecutor.shutdownNow() - } - } - - override def deserialize[T](deserializationAction: () => T): T = { - NettyRpcEnv.currentEnv.withValue(this) { - deserializationAction() - } - } -} - -private[netty] object NettyRpcEnv extends Logging { - - /** - * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. - * Use `currentEnv` to wrap the deserialization codes. E.g., - * - * {{{ - * NettyRpcEnv.currentEnv.withValue(this) { - * your deserialization codes - * } - * }}} - */ - private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null) -} - -private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { - - def create(config: RpcEnvConfig): RpcEnv = { - val sparkConf = config.conf - // Use JavaSerializerInstance in multiple threads is safe. 
However, if we plan to support - // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance - val javaSerializerInstance = - new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] - val nettyEnv = - new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) - val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => - nettyEnv.start(actualPort) - (nettyEnv, actualPort) - } - try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 - } catch { - case NonFatal(e) => - nettyEnv.shutdown() - throw e - } - } -} - -private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) - extends RpcEndpointRef(conf) with Serializable with Logging { - - @transient @volatile private var nettyEnv: NettyRpcEnv = _ - - @transient @volatile private var _address: NettyRpcAddress = _ - - def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) { - this(conf) - this._address = _address - this.nettyEnv = nettyEnv - } - - override def address: RpcAddress = _address.toRpcAddress - - private def readObject(in: ObjectInputStream): Unit = { - in.defaultReadObject() - _address = in.readObject().asInstanceOf[NettyRpcAddress] - nettyEnv = NettyRpcEnv.currentEnv.value - } - - private def writeObject(out: ObjectOutputStream): Unit = { - out.defaultWriteObject() - out.writeObject(_address) - } - - override def name: String = _address.name - - override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - val promise = Promise[Any]() - val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable { - override def run(): Unit = { - promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration)) - } - }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) - val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true)) - f.onComplete { v => - timeoutCancelable.cancel(true) - if (!promise.tryComplete(v)) { - logWarning(s"Ignore message $v") - } - }(ThreadUtils.sameThread) - promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) - } - - override def send(message: Any): Unit = { - require(message != null, "Message is null") - nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false)) - } - - override def toString: String = s"NettyRpcEndpointRef(${_address})" - - def toURI: URI = new URI(s"spark://${_address}") - - final override def equals(that: Any): Boolean = that match { - case other: NettyRpcEndpointRef => _address == other._address - case _ => false - } - - final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode() -} - -/** - * The message that is sent from the sender to the receiver. - */ -private[netty] case class RequestMessage( - senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean) - -/** - * The base trait for all messages that are sent back from the receiver to the sender. - */ -private[netty] trait ResponseMessage - -/** - * The reply for `ask` from the receiver side. - */ -private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any) - extends ResponseMessage - -/** - * A message to send back to the receiver side. It's necessary because [[TransportClient]] only - * clean the resources when it receives a reply. - */ -private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage - -/** - * A response that indicates some failure happens in the receiver side. 
- */ -private[netty] case class RpcFailure(e: Throwable) - -/** - * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast - * network events and forward messages to [[Dispatcher]]. - */ -private[netty] class NettyRpcHandler( - dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { - - private type ClientAddress = RpcAddress - private type RemoteEnvAddress = RpcAddress - - // Store all client addresses and their NettyRpcEnv addresses. - @GuardedBy("this") - private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() - - // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection - // count because `TransportClientFactory.createClient` will create multiple connections - // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection - // to send the message. See `TransportClientFactory.createClient` for more details. - @GuardedBy("this") - private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]() - - override def receive( - client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { - val requestMessage = nettyEnv.deserialize[RequestMessage](message) - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] - assert(addr != null) - val remoteEnvAddress = requestMessage.senderAddress - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - // If the first connection to a remote RpcEnv is found, we should broadcast "Associated" - if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { - // clientAddr connects at the first time - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - // Increase the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count + 1) - if (count == 0) { - // This is the first connection, so fire "Associated" - Some(Associated(remoteEnvAddress)) - } else { - None - } - } else { - None - } - } - broadcastMessage.foreach(dispatcher.broadcastMessage) - dispatcher.postMessage(requestMessage, callback) - } - - override def getStreamManager: StreamManager = new OneForOneStreamManager - - override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] - if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - remoteAddresses.get(clientAddr).map(AssociationError(cause, _)) - } - if (broadcastMessage.isEmpty) { - logError(cause.getMessage, cause) - } else { - dispatcher.broadcastMessage(broadcastMessage.get) - } - } else { - // If the channel is closed before connecting, its remoteAddress will be null. 
- // See java.net.Socket.getRemoteSocketAddress - // Because we cannot get a RpcAddress, just log it - logError("Exception before connecting to the client", cause) - } - } - - override def connectionTerminated(client: TransportClient): Unit = { - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] - if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - // If the last connection to a remote RpcEnv is terminated, we should broadcast - // "Disassociated" - remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => - remoteAddresses -= clientAddr - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent") - if (count - 1 == 0) { - // We lost all clients, so clean up and fire "Disassociated" - remoteConnectionCount.remove(remoteEnvAddress) - Some(Disassociated(remoteEnvAddress)) - } else { - // Decrease the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count - 1) - None - } - } - } - broadcastMessage.foreach(dispatcher.broadcastMessage) - } else { - // If the channel is closed before connecting, its remoteAddress will be null. In this case, - // we can ignore it since we don't fire "Associated". - // See java.net.Socket.getRemoteSocketAddress - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index e749631bf6..7478ab0fc2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ @@ -33,7 +33,7 @@ class BlockManagerSlaveEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends ThreadSafeRpcEndpoint with Logging { + extends RpcEndpoint with Logging { private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") @@ -80,7 +80,7 @@ class BlockManagerSlaveEndpoint( future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) context.reply(response) - logDebug("Sent response: " + response + " to " + context.senderAddress) + logDebug("Sent response: " + response + " to " + context.sender) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 1ed098379e..22e291a2b4 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -85,11 +85,7 @@ private[spark] object ThreadUtils { */ def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() - val executor = new ScheduledThreadPoolExecutor(1, threadFactory) - // By 
default, a cancelled task is not automatically removed from the work queue until its delay - // elapses. We have to enable it manually. - executor.setRemoveOnCancelPolicy(true) - executor + Executors.newSingleThreadScheduledExecutor(threadFactory) } /** diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7e70308bb3..af4e68950f 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -168,9 +168,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - val senderAddress = RpcAddress("localhost", 12345) + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.senderAddress).thenReturn(senderAddress) + when(rpcCallContext.sender).thenReturn(sender) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) verify(rpcCallContext).reply(any()) verify(rpcCallContext, never()).sendFailure(any()) @@ -197,9 +198,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - val senderAddress = RpcAddress("localhost", 12345) + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.senderAddress).thenReturn(senderAddress) + when(rpcCallContext.sender).thenReturn(sender) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) verify(rpcCallContext, never()).reply(any()) verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 2d14249855..33270bec62 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -41,7 +41,6 @@ object SSLSampleConfigs { def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -55,7 +54,6 @@ object SSLSampleConfigs { def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) conf.set("spark.ssl.keyStorePassword", "password") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 40c24bdecc..e9034e39a7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -26,7 +26,8 @@ class WorkerWatcherSuite extends SparkFunSuite { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") - val 
workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) + workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) @@ -38,7 +39,8 @@ class WorkerWatcherSuite extends SparkFunSuite { val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val otherRpcAddress = RpcAddress("4.3.2.1", 1234) - val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) + workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(otherRpcAddress) assert(!workerWatcher.isShutDown) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index e836946a59..6ceafe4337 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.rpc -import java.io.NotSerializableException import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -100,6 +99,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) + val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") val reply = newRpcEndpointRef.askWithRetry[String]("Echo") assert("Echo" === reply) @@ -516,9 +516,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(events === List( ("onConnected", remoteAddress), ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress)) || - events === List( - ("onConnected", remoteAddress), ("onDisconnected", remoteAddress))) } } @@ -538,84 +535,15 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { "local", env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") - val e = intercept[Exception] { + intercept[TimeoutException] { Await.result(f, 1 seconds) } - assert(e.isInstanceOf[TimeoutException] || // For Akka - e.isInstanceOf[NotSerializableException] // For Netty - ) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() } } - test("port conflict") { - val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) - assert(anotherEnv.address.port != env.address.port) - } - - test("send with authentication") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) - - try { - @volatile var message: String = null - localEnv.setupEndpoint("send-authentication", new RpcEndpoint { - override val rpcEnv = localEnv - - override def receive: PartialFunction[Any, Unit] = { - case msg: String => message = msg - } - }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication") - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(10 millis)) { - assert("hello" === message) - } - } finally { - localEnv.shutdown() - 
localEnv.awaitTermination() - remoteEnv.shutdown() - remoteEnv.awaitTermination() - } - } - - test("ask with authentication") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) - - try { - localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { - override val rpcEnv = localEnv - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { - context.reply(msg) - } - } - }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication") - val reply = rpcEndpointRef.askWithRetry[String]("hello") - assert("hello" === reply) - } finally { - localEnv.shutdown() - localEnv.awaitTermination() - remoteEnv.shutdown() - remoteEnv.awaitTermination() - } - } - test("construct RpcTimeout with conf property") { val conf = new SparkConf @@ -684,7 +612,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // once the future is complete to verify addMessageIfTimeout was invoked val reply3 = intercept[RpcTimeoutException] { - Await.result(fut3, 2000 millis) + Await.result(fut3, 200 millis) }.getMessage // When the future timed out, the recover callback should have used diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala deleted file mode 100644 index 5e8da3e205..0000000000 --- a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.rpc - -import scala.collection.mutable.ArrayBuffer - -import org.scalactic.TripleEquals - -class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals { - - override val rpcEnv: RpcEnv = null - - @volatile private var receiveMessages = ArrayBuffer[Any]() - - @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]() - - @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]() - - @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]() - - @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable, RpcAddress)]() - - @volatile private var started = false - - @volatile private var stopped = false - - override def receive: PartialFunction[Any, Unit] = { - case message: Any => receiveMessages += message - } - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case message: Any => receiveAndReplyMessages += message - } - - override def onConnected(remoteAddress: RpcAddress): Unit = { - onConnectedMessages += remoteAddress - } - - /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. - */ - override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - onNetworkErrorMessages += cause -> remoteAddress - } - - override def onDisconnected(remoteAddress: RpcAddress): Unit = { - onDisconnectedMessages += remoteAddress - } - - def numReceiveMessages: Int = receiveMessages.size - - override def onStart(): Unit = { - started = true - } - - override def onStop(): Unit = { - stopped = true - } - - def verifyStarted(): Unit = { - assert(started, "RpcEndpoint is not started") - } - - def verifyStopped(): Unit = { - assert(stopped, "RpcEndpoint is not stopped") - } - - def verifyReceiveMessages(expected: Seq[Any]): Unit = { - assert(receiveMessages === expected) - } - - def verifySingleReceiveMessage(message: Any): Unit = { - verifyReceiveMessages(List(message)) - } - - def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = { - assert(receiveAndReplyMessages === expected) - } - - def verifySingleReceiveAndReplyMessage(message: Any): Unit = { - verifyReceiveAndReplyMessages(List(message)) - } - - def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = { - verifyOnConnectedMessages(List(remoteAddress)) - } - - def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = { - assert(onConnectedMessages === expected) - } - - def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = { - verifyOnDisconnectedMessages(List(remoteAddress)) - } - - def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = { - assert(onDisconnectedMessages === expected) - } - - def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress: RpcAddress): Unit = { - verifyOnNetworkErrorMessages(List(cause -> remoteAddress)) - } - - def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]): Unit = { - assert(onNetworkErrorMessages === expected) - } -} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala deleted file mode 100644 index ff83ab9b32..0000000000 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.netty - -import java.util.concurrent.{CountDownLatch, TimeUnit} -import java.util.concurrent.atomic.AtomicInteger - -import org.mockito.Mockito._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint} - -class InboxSuite extends SparkFunSuite { - - test("post") { - val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - when(endpointRef.name).thenReturn("hello") - - val dispatcher = mock(classOf[Dispatcher]) - - val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", false, null) - inbox.post(message) - inbox.process(dispatcher) - assert(inbox.isEmpty) - - endpoint.verifySingleReceiveMessage("hi") - - inbox.stop() - inbox.process(dispatcher) - assert(inbox.isEmpty) - endpoint.verifyStarted() - endpoint.verifyStopped() - } - - test("post: with reply") { - val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - val dispatcher = mock(classOf[Dispatcher]) - - val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", true, null) - inbox.post(message) - inbox.process(dispatcher) - assert(inbox.isEmpty) - - endpoint.verifySingleReceiveAndReplyMessage("hi") - } - - test("post: multiple threads") { - val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - when(endpointRef.name).thenReturn("hello") - - val dispatcher = mock(classOf[Dispatcher]) - - val numDroppedMessages = new AtomicInteger(0) - val inbox = new Inbox(endpointRef, endpoint) { - override def onDrop(message: InboxMessage): Unit = { - numDroppedMessages.incrementAndGet() - } - } - - val exitLatch = new CountDownLatch(10) - - for (_ <- 0 until 10) { - new Thread { - override def run(): Unit = { - for (_ <- 0 until 100) { - val message = ContentMessage(null, "hi", false, null) - inbox.post(message) - } - exitLatch.countDown() - } - }.start() - } - inbox.process(dispatcher) - assert(inbox.isEmpty) - inbox.stop() - inbox.process(dispatcher) - assert(inbox.isEmpty) - - exitLatch.await(30, TimeUnit.SECONDS) - - assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get) - endpoint.verifyStarted() - endpoint.verifyStopped() - } - - test("post: Associated") { - val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - val dispatcher = mock(classOf[Dispatcher]) - - val remoteAddress = RpcAddress("localhost", 11111) - - val inbox = new Inbox(endpointRef, endpoint) - inbox.post(Associated(remoteAddress)) - inbox.process(dispatcher) - - endpoint.verifySingleOnConnectedMessage(remoteAddress) - } - - test("post: Disassociated") { - val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - val dispatcher = 
mock(classOf[Dispatcher]) - - val remoteAddress = RpcAddress("localhost", 11111) - - val inbox = new Inbox(endpointRef, endpoint) - inbox.post(Disassociated(remoteAddress)) - inbox.process(dispatcher) - - endpoint.verifySingleOnDisconnectedMessage(remoteAddress) - } - - test("post: AssociationError") { - val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - val dispatcher = mock(classOf[Dispatcher]) - - val remoteAddress = RpcAddress("localhost", 11111) - val cause = new RuntimeException("Oops") - - val inbox = new Inbox(endpointRef, endpoint) - inbox.post(AssociationError(cause, remoteAddress)) - inbox.process(dispatcher) - - endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress) - } -} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala deleted file mode 100644 index a5d43d3704..0000000000 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.netty - -import org.apache.spark.SparkFunSuite - -class NettyRpcAddressSuite extends SparkFunSuite { - - test("toString") { - val addr = NettyRpcAddress("localhost", 12345, "test") - assert(addr.toString === "spark://test@localhost:12345") - } - -} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala deleted file mode 100644 index be19668e17..0000000000 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.rpc.netty - -import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.rpc._ - -class NettyRpcEnvSuite extends RpcEnvSuite { - - override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { - val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf)) - new NettyRpcEnvFactory().create(config) - } - - test("non-existent endpoint") { - val uri = env.uriOf("test", env.address, "nonexist-endpoint") - val e = intercept[RpcEndpointNotFoundException] { - env.setupEndpointRef("test", env.address, "nonexist-endpoint") - } - assert(e.getMessage.contains(uri)) - } - -} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala deleted file mode 100644 index 06ca035d19..0000000000 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.netty - -import java.net.InetSocketAddress - -import io.netty.channel.Channel -import org.mockito.Mockito._ -import org.mockito.Matchers._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} -import org.apache.spark.rpc._ - -class NettyRpcHandlerSuite extends SparkFunSuite { - - val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[Array[Byte]]))(any())). 
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) - - test("receive") { - val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) - - val channel = mock(classOf[Channel]) - val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) - when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) - - when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001)) - nettyRpcHandler.receive(client, null, null) - - verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345))) - } - - test("connectionTerminated") { - val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) - - val channel = mock(classOf[Channel]) - val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) - when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) - - when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.connectionTerminated(client) - - verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345))) - verify(dispatcher, times(1)).broadcastMessage(Disassociated(RpcAddress("localhost", 12345))) - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index fbb8bb6b2f..df841288a0 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -78,10 +78,6 @@ public class TransportClient implements Closeable { this.handler = Preconditions.checkNotNull(handler); } - public Channel getChannel() { - return channel; - } - public boolean isActive() { return channel.isOpen() || channel.isActive(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index dbb7f95f55..2ba92a40f8 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -52,6 +52,4 @@ public abstract class RpcHandler { * No further requests will come from this client. 
*/ public void connectionTerminated(TransportClient client) { } - - public void exceptionCaught(Throwable cause, TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 96941d26be..df60278058 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -71,7 +71,6 @@ public class TransportRequestHandler extends MessageHandler { @Override public void exceptionCaught(Throwable cause) { - rpcHandler.exceptionCaught(cause, reverseClient); } @Override diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index d053e9e849..204e6142fd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -474,7 +474,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Remote messages case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => val successful = - registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.senderAddress) + registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 32d218169a..93621b44c9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -556,7 +556,10 @@ private[spark] class ApplicationMaster( override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) extends RpcEndpoint with Logging { - driver.send(RegisterClusterManager(self)) + override def onStart(): Unit = { + driver.send(RegisterClusterManager(self)) + + } override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => -- cgit v1.2.3
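
The change that recurs throughout this revert -- in MapOutputTracker.scala, BlockManagerSlaveEndpoint.scala, ReceiverTracker.scala and MapOutputTrackerSuite.scala above -- is the return from context.senderAddress to the sender-based RpcCallContext API (context.sender.address). The sketch below is illustrative only and is not part of the patch: it assumes just the members visible in the hunks above (RpcCallContext.sender returning an RpcEndpointRef, RpcEndpointRef.address returning an RpcAddress with a hostPort string), and the EchoEndpoint name is hypothetical.

    package org.apache.spark.rpc

    import org.mockito.Mockito.{mock, verify, when}

    object SenderAddressSketch {

      // Hypothetical endpoint: replies with the caller's host:port, obtained the same way
      // MapOutputTrackerMasterEndpoint does after the revert (context.sender.address.hostPort).
      class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
        override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
          case msg: String =>
            val hostPort = context.sender.address.hostPort
            context.reply(s"$msg from $hostPort")
        }
      }

      def main(args: Array[String]): Unit = {
        // Test-side wiring mirrors MapOutputTrackerSuite above: mock the sender ref,
        // stub its address, and expose it through a mocked call context.
        val sender = mock(classOf[RpcEndpointRef])
        when(sender.address).thenReturn(RpcAddress("localhost", 12345))
        val context = mock(classOf[RpcCallContext])
        when(context.sender).thenReturn(sender)

        new EchoEndpoint(rpcEnv = null).receiveAndReply(context)("hello")
        verify(context).reply("hello from localhost:12345")
      }
    }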
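
The ThreadUtils.scala hunk above likewise drops an explicit ScheduledThreadPoolExecutor in favour of Executors.newSingleThreadScheduledExecutor. Purely for illustration -- this restates the two variants already visible in that hunk, nothing more -- the difference is whether cancelled tasks are removed from the work queue immediately or only once their delay elapses:

    import java.util.concurrent.{Executors, ScheduledExecutorService, ScheduledThreadPoolExecutor}

    import com.google.common.util.concurrent.ThreadFactoryBuilder

    object ScheduledExecutorSketch {

      // What the revert restores: the plain JDK single-thread scheduled executor.
      def restoredVariant(threadName: String): ScheduledExecutorService = {
        val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
        Executors.newSingleThreadScheduledExecutor(threadFactory)
      }

      // What the revert deletes: the same single-thread pool, but with cancelled tasks
      // dropped from the work queue right away rather than lingering until their delay elapses.
      def removedVariant(threadName: String): ScheduledExecutorService = {
        val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
        val executor = new ScheduledThreadPoolExecutor(1, threadFactory)
        executor.setRemoveOnCancelPolicy(true)
        executor
      }
    }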