From 107320c9bbfe2496963a4e75e60fd6ba7fbfbabc Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 3 Oct 2015 01:04:35 -0700 Subject: [SPARK-6028] [CORE] Remerge #6457: new RPC implementation and also pick #8905 This PR just reverted https://github.com/apache/spark/commit/02144d6745ec0a6d8877d969feb82139bd22437f to remerge #6457 and also included the commits in #8905. Author: zsxwing Closes #8944 from zsxwing/SPARK-6028. --- .../scala/org/apache/spark/MapOutputTracker.scala | 2 +- .../src/main/scala/org/apache/spark/SparkEnv.scala | 20 +- .../org/apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/deploy/worker/WorkerWatcher.scala | 13 +- .../org/apache/spark/rpc/RpcCallContext.scala | 2 +- .../scala/org/apache/spark/rpc/RpcEndpoint.scala | 51 ++- .../spark/rpc/RpcEndpointNotFoundException.scala | 22 + .../main/scala/org/apache/spark/rpc/RpcEnv.scala | 5 +- .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 6 +- .../org/apache/spark/rpc/netty/Dispatcher.scala | 218 +++++++++ .../org/apache/spark/rpc/netty/IDVerifier.scala | 39 ++ .../scala/org/apache/spark/rpc/netty/Inbox.scala | 227 ++++++++++ .../apache/spark/rpc/netty/NettyRpcAddress.scala | 56 +++ .../spark/rpc/netty/NettyRpcCallContext.scala | 87 ++++ .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 504 +++++++++++++++++++++ .../spark/storage/BlockManagerSlaveEndpoint.scala | 6 +- .../scala/org/apache/spark/util/ThreadUtils.scala | 6 +- .../org/apache/spark/MapOutputTrackerSuite.scala | 10 +- .../scala/org/apache/spark/SSLSampleConfigs.scala | 2 + .../spark/deploy/worker/WorkerWatcherSuite.scala | 6 +- .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 81 +++- .../org/apache/spark/rpc/TestRpcEndpoint.scala | 123 +++++ .../org/apache/spark/rpc/netty/InboxSuite.scala | 150 ++++++ .../spark/rpc/netty/NettyRpcAddressSuite.scala | 29 ++ .../apache/spark/rpc/netty/NettyRpcEnvSuite.scala | 38 ++ .../spark/rpc/netty/NettyRpcHandlerSuite.scala | 67 +++ .../spark/network/client/TransportClient.java | 4 + .../apache/spark/network/server/RpcHandler.java | 2 + .../network/server/TransportRequestHandler.java | 1 + .../streaming/scheduler/ReceiverTracker.scala | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 5 +- 31 files changed, 1715 insertions(+), 71 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index e4cb72e39d..45e12e40c8 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -45,7 +45,7 @@ private[spark] class MapOutputTrackerMasterEndpoint( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = context.sender.address.hostPort + val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index c6fef7f91f..cfde27fb2e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,11 +20,10 @@ package org.apache.spark import java.io.File import java.net.Socket -import akka.actor.ActorSystem - import scala.collection.mutable import scala.util.Properties +import akka.actor.ActorSystem import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -41,7 +40,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} /** * :: DeveloperApi :: @@ -57,6 +56,7 @@ import org.apache.spark.util.{RpcUtils, Utils} class SparkEnv ( val executorId: String, private[spark] val rpcEnv: RpcEnv, + _actorSystem: ActorSystem, // TODO Remove actorSystem val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -76,7 +76,7 @@ class SparkEnv ( // TODO Remove actorSystem @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") - val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + val actorSystem: ActorSystem = _actorSystem private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -100,6 +100,9 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() + if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { + actorSystem.shutdown() + } rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -249,7 +252,13 @@ object SparkEnv extends Logging { // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + val actorSystem: ActorSystem = + if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { + rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + } else { + // Create an ActorSystem for legacy code + AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 + } // Figure out which port Akka actually bound to in case the original port is 0 or occupied.
if (isDriver) { @@ -395,6 +404,7 @@ object SparkEnv extends Logging { val envInstance = new SparkEnv( executorId, rpcEnv, + actorSystem, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 770927c80f..93a1b3f310 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -329,7 +329,7 @@ private[deploy] class Worker( registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - self.send(ReregisterWithMaster) + Option(self).foreach(_.send(ReregisterWithMaster)) } }, INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 735c4f0927..ab56fde938 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -24,14 +24,13 @@ import org.apache.spark.rpc._ * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) +private[spark] class WorkerWatcher( + override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false) extends RpcEndpoint with Logging { - override def onStart() { - logInfo(s"Connecting to worker $workerUrl") - if (!isTesting) { - rpcEnv.asyncSetupEndpointRefByURI(workerUrl) - } + logInfo(s"Connecting to worker $workerUrl") + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) } // Used to avoid shutting down JVM during tests @@ -40,8 +39,6 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin // true rather than calling `System.exit`. The user can check `isShutDown` to know if // `exitNonZero` is called. private[deploy] var isShutDown = false - private[deploy] def setTesting(testing: Boolean) = isTesting = testing - private var isTesting = false // Lets filter events only from the worker's rpc system private val expectedAddress = RpcAddress.fromURIString(workerUrl) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index 3e5b64265e..f527ec86ab 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -37,5 +37,5 @@ private[spark] trait RpcCallContext { /** * The sender of this message. */ - def sender: RpcEndpointRef + def senderAddress: RpcAddress } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index dfcbc51cdf..f1ddc6d2cd 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -28,20 +28,6 @@ private[spark] trait RpcEnvFactory { def create(config: RpcEnvConfig): RpcEnv } -/** - * A trait that requires RpcEnv thread-safely sending messages to it. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[ThreadSafeRpcEndpoint]]. 
In the other words, changes to internal fields of a - * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the - * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same - * [[ThreadSafeRpcEndpoint]] for different messages. - */ -private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint - - /** * An end point for the RPC that defines what functions to trigger given a message. * @@ -101,38 +87,39 @@ private[spark] trait RpcEndpoint { } /** - * Invoked before [[RpcEndpoint]] starts to handle any message. + * Invoked when `remoteAddress` is connected to the current node. */ - def onStart(): Unit = { + def onConnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when [[RpcEndpoint]] is stopping. + * Invoked when `remoteAddress` is lost. */ - def onStop(): Unit = { + def onDisconnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when `remoteAddress` is connected to the current node. + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. */ - def onConnected(remoteAddress: RpcAddress): Unit = { + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when `remoteAddress` is lost. + * Invoked before [[RpcEndpoint]] starts to handle any message. */ - def onDisconnected(remoteAddress: RpcAddress): Unit = { + def onStart(): Unit = { // By default, do nothing. } /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. + * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot + * use it to send or ask messages. */ - def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + def onStop(): Unit = { // By default, do nothing. } @@ -146,3 +133,17 @@ private[spark] trait RpcEndpoint { } } } + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint { +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala new file mode 100644 index 0000000000..d177881fb3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +private[rpc] class RpcEndpointNotFoundException(uri: String) + extends SparkException(s"Cannot find endpoint: $uri") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 29debe8081..35e402c725 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -36,8 +36,9 @@ private[spark] object RpcEnv { private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", + "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "netty") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index ad67e1c5ad..95132a4e4a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -166,9 +166,9 @@ private[spark] class AkkaRpcEnv private[akka] ( _sender ! AkkaMessage(response, false) } - // Some RpcEndpoints need to know the sender's address - override val sender: RpcEndpointRef = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + // Use "lazy" because most of RpcEndpoints don't need "senderAddress" + override lazy val senderAddress: RpcAddress = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address }) } else { endpoint.receive diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala new file mode 100644 index 0000000000..d71e6f01db --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.concurrent.Promise +import scala.util.control.NonFatal + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.rpc._ +import org.apache.spark.util.ThreadUtils + +private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { + + private class EndpointData( + val name: String, + val endpoint: RpcEndpoint, + val ref: NettyRpcEndpointRef) { + val inbox = new Inbox(ref, endpoint) + } + + private val endpoints = new ConcurrentHashMap[String, EndpointData]() + private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + // Track the receivers whose inboxes may contain messages. + private val receivers = new LinkedBlockingQueue[EndpointData]() + + @GuardedBy("this") + private var stopped = false + + def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { + val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) + val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) + synchronized { + if (stopped) { + throw new IllegalStateException("RpcEnv has been stopped") + } + if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) { + throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") + } + val data = endpoints.get(name) + endpointRefs.put(data.endpoint, data.ref) + receivers.put(data) + } + endpointRef + } + + def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint) + + def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint) + + // Should be idempotent + private def unregisterRpcEndpoint(name: String): Unit = { + val data = endpoints.remove(name) + if (data != null) { + data.inbox.stop() + receivers.put(data) + } + // Don't clean `endpointRefs` here because it's possible that some messages are being processed + // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via + // `removeRpcEndpointRef`. + } + + def stop(rpcEndpointRef: RpcEndpointRef): Unit = { + synchronized { + if (stopped) { + // This endpoint will be stopped by Dispatcher.stop() method. + return + } + unregisterRpcEndpoint(rpcEndpointRef.name) + } + } + + /** + * Send a message to all registered [[RpcEndpoint]]s. 
+ * @param message + */ + def broadcastMessage(message: InboxMessage): Unit = { + val iter = endpoints.keySet().iterator() + while (iter.hasNext) { + val name = iter.next + postMessageToInbox(name, (_) => message, + () => { logWarning(s"Drop ${message} because ${name} has been stopped") }) + } + } + + def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { + def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { + val rpcCallContext = + new RemoteNettyRpcCallContext( + nettyEnv, sender, callback, message.senderAddress, message.needReply) + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) + } + + def onEndpointStopped(): Unit = { + callback.onFailure( + new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) + } + + postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) + } + + def postMessage(message: RequestMessage, p: Promise[Any]): Unit = { + def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { + val rpcCallContext = + new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) + } + + def onEndpointStopped(): Unit = { + p.tryFailure( + new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) + } + + postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) + } + + private def postMessageToInbox( + endpointName: String, + createMessageFn: NettyRpcEndpointRef => InboxMessage, + onStopped: () => Unit): Unit = { + val shouldCallOnStop = + synchronized { + val data = endpoints.get(endpointName) + if (stopped || data == null) { + true + } else { + data.inbox.post(createMessageFn(data.ref)) + receivers.put(data) + false + } + } + if (shouldCallOnStop) { + // We don't need to call `onStop` in the `synchronized` block + onStopped() + } + } + + private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism", + Runtime.getRuntime.availableProcessors()) + + private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop") + + (0 until parallelism) foreach { _ => + executor.execute(new MessageLoop) + } + + def stop(): Unit = { + synchronized { + if (stopped) { + return + } + stopped = true + } + // Stop all endpoints. This will queue all endpoints for processing by the message loops. + endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) + // Enqueue a message that tells the message loops to stop. + receivers.put(PoisonEndpoint) + executor.shutdown() + } + + def awaitTermination(): Unit = { + executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + } + + /** + * Return if the endpoint exists + */ + def verify(name: String): Boolean = { + endpoints.containsKey(name) + } + + private class MessageLoop extends Runnable { + override def run(): Unit = { + try { + while (true) { + try { + val data = receivers.take() + if (data == PoisonEndpoint) { + // Put PoisonEndpoint back so that other MessageLoops can see it. + receivers.put(PoisonEndpoint) + return + } + data.inbox.process(Dispatcher.this) + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + + /** + * A poison endpoint that indicates MessageLoop should exit its loop. 
+ */ + private val PoisonEndpoint = new EndpointData(null, null, null) +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala new file mode 100644 index 0000000000..6061c9b8de --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} + +/** + * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists + */ +private[netty] case class ID(name: String) + +/** + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]] + */ +private[netty] class IDVerifier( + override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case ID(name) => context.reply(dispatcher.verify(name)) + } +} + +private[netty] object IDVerifier { + val NAME = "id-verifier" +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala new file mode 100644 index 0000000000..b669f59a28 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import java.util.LinkedList +import javax.annotation.concurrent.GuardedBy + +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} + +private[netty] sealed trait InboxMessage + +private[netty] case class ContentMessage( + senderAddress: RpcAddress, + content: Any, + needReply: Boolean, + context: NettyRpcCallContext) extends InboxMessage + +private[netty] case object OnStart extends InboxMessage + +private[netty] case object OnStop extends InboxMessage + +/** + * A broadcast message that indicates connecting to a remote node. + */ +private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage + +/** + * A broadcast message that indicates a remote connection is lost. + */ +private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage + +/** + * A broadcast message that indicates a network error. + */ +private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress) + extends InboxMessage + +/** + * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. + * @param endpointRef + * @param endpoint + */ +private[netty] class Inbox( + val endpointRef: NettyRpcEndpointRef, + val endpoint: RpcEndpoint) extends Logging { + + inbox => + + @GuardedBy("this") + protected val messages = new LinkedList[InboxMessage]() + + @GuardedBy("this") + private var stopped = false + + @GuardedBy("this") + private var enableConcurrent = false + + @GuardedBy("this") + private var workerCount = 0 + + // OnStart should be the first message to process + inbox.synchronized { + messages.add(OnStart) + } + + /** + * Process stored messages.
+ */ + def process(dispatcher: Dispatcher): Unit = { + var message: InboxMessage = null + inbox.synchronized { + if (!enableConcurrent && workerCount != 0) { + return + } + message = messages.poll() + if (message != null) { + workerCount += 1 + } else { + return + } + } + while (true) { + safelyCall(endpoint) { + message match { + case ContentMessage(_sender, content, needReply, context) => + val pf: PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(context) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + if (!needReply) { + context.finish() + } + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks a reply, we should send the error back to the sender + context.sendFailure(e) + } else { + context.finish() + throw e + } + } + + case OnStart => { + endpoint.onStart() + if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + inbox.synchronized { + if (!stopped) { + enableConcurrent = true + } + } + } + } + + case OnStop => + val _workCount = inbox.synchronized { + workerCount + } + assert(_workCount == 1, s"There should be only one worker but was ${_workCount}") + dispatcher.removeRpcEndpointRef(endpoint) + endpoint.onStop() + assert(isEmpty, "OnStop should be the last message") + + case Associated(remoteAddress) => + endpoint.onConnected(remoteAddress) + + case Disassociated(remoteAddress) => + endpoint.onDisconnected(remoteAddress) + + case AssociationError(cause, remoteAddress) => + endpoint.onNetworkError(cause, remoteAddress) + } + } + + inbox.synchronized { + // "enableConcurrent" will be set to false after `onStop` is called, so we should check it + // every time. + if (!enableConcurrent && workerCount != 1) { + // If we are not the only one worker, exit + workerCount -= 1 + return + } + message = messages.poll() + if (message == null) { + workerCount -= 1 + return + } + } + } + } + + def post(message: InboxMessage): Unit = { + val dropped = + inbox.synchronized { + if (stopped) { + // We already put "OnStop" into "messages", so we should drop further messages + true + } else { + messages.add(message) + false + } + } + if (dropped) { + onDrop(message) + } + } + + def stop(): Unit = inbox.synchronized { + // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last + // message + if (!stopped) { + // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only + // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources + // safely. + enableConcurrent = false + stopped = true + messages.add(OnStop) + // Note: The concurrent events in messages will be processed one by one. + } + } + + // Visible for testing. 
+ protected def onDrop(message: InboxMessage): Unit = { + logWarning(s"Drop ${message} because $endpointRef is stopped") + } + + def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } + + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logWarning(s"Ignore error", e) + } + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala new file mode 100644 index 0000000000..1876b25592 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.net.URI + +import org.apache.spark.SparkException +import org.apache.spark.rpc.RpcAddress + +private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) { + + def toRpcAddress: RpcAddress = RpcAddress(host, port) + + override val toString = s"spark://$name@$host:$port" +} + +private[netty] object NettyRpcAddress { + + def apply(sparkUrl: String): NettyRpcAddress = { + try { + val uri = new URI(sparkUrl) + val host = uri.getHost + val port = uri.getPort + val name = uri.getUserInfo + if (uri.getScheme != "spark" || + host == null || + port < 0 || + name == null || + (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null + uri.getFragment != null || + uri.getQuery != null) { + throw new SparkException("Invalid Spark URL: " + sparkUrl) + } + NettyRpcAddress(host, port, name) + } catch { + case e: java.net.URISyntaxException => + throw new SparkException("Invalid Spark URL: " + sparkUrl, e) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala new file mode 100644 index 0000000000..75dcc02a0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import scala.concurrent.Promise + +import org.apache.spark.Logging +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.rpc.{RpcAddress, RpcCallContext} + +private[netty] abstract class NettyRpcCallContext( + endpointRef: NettyRpcEndpointRef, + override val senderAddress: RpcAddress, + needReply: Boolean) extends RpcCallContext with Logging { + + protected def send(message: Any): Unit + + override def reply(response: Any): Unit = { + if (needReply) { + send(AskResponse(endpointRef, response)) + } else { + throw new IllegalStateException( + s"Cannot send $response to the sender because the sender won't handle it") + } + } + + override def sendFailure(e: Throwable): Unit = { + if (needReply) { + send(AskResponse(endpointRef, RpcFailure(e))) + } else { + logError(e.getMessage, e) + throw new IllegalStateException( + "Cannot send reply to the sender because the sender won't handle it") + } + } + + def finish(): Unit = { + if (!needReply) { + send(Ack(endpointRef)) + } + } +} + +/** + * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. + */ +private[netty] class LocalNettyRpcCallContext( + endpointRef: NettyRpcEndpointRef, + senderAddress: RpcAddress, + needReply: Boolean, + p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + + override protected def send(message: Any): Unit = { + p.success(message) + } +} + +/** + * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back. + */ +private[netty] class RemoteNettyRpcCallContext( + nettyEnv: NettyRpcEnv, + endpointRef: NettyRpcEndpointRef, + callback: RpcResponseCallback, + senderAddress: RpcAddress, + needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + + override protected def send(message: Any): Unit = { + val reply = nettyEnv.serialize(message) + callback.onSuccess(reply) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala new file mode 100644 index 0000000000..5522b40782 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -0,0 +1,504 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.rpc.netty + +import java.io._ +import java.net.{InetSocketAddress, URI} +import java.nio.ByteBuffer +import java.util.Arrays +import java.util.concurrent._ +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.concurrent.{Future, Promise} +import scala.reflect.ClassTag +import scala.util.{DynamicVariable, Failure, Success} +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.client._ +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} +import org.apache.spark.network.server._ +import org.apache.spark.rpc._ +import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} +import org.apache.spark.util.{ThreadUtils, Utils} + +private[netty] class NettyRpcEnv( + val conf: SparkConf, + javaSerializerInstance: JavaSerializerInstance, + host: String, + securityManager: SecurityManager) extends RpcEnv(conf) with Logging { + + private val transportConf = + SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0)) + + private val dispatcher: Dispatcher = new Dispatcher(this) + + private val transportContext = + new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) + + private val clientFactory = { + val bootstraps: Seq[TransportClientBootstrap] = + if (securityManager.isAuthenticationEnabled()) { + Seq(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + Nil + } + transportContext.createClientFactory(bootstraps.asJava) + } + + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") + + // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool + // to implement non-blocking send/ask. 
+ // TODO: a non-blocking TransportClientFactory.createClient in future + private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( + "netty-rpc-connection", + conf.getInt("spark.rpc.connect.threads", 256)) + + @volatile private var server: TransportServer = _ + + def start(port: Int): Unit = { + val bootstraps: Seq[TransportServerBootstrap] = + if (securityManager.isAuthenticationEnabled()) { + Seq(new SaslServerBootstrap(transportConf, securityManager)) + } else { + Nil + } + server = transportContext.createServer(port, bootstraps.asJava) + dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) + } + + override lazy val address: RpcAddress = { + require(server != null, "NettyRpcEnv has not yet started") + RpcAddress(host, server.getPort()) + } + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + dispatcher.registerRpcEndpoint(name, endpoint) + } + + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { + val addr = NettyRpcAddress(uri) + val endpointRef = new NettyRpcEndpointRef(conf, addr, this) + val idVerifierRef = + new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this) + idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find => + if (find) { + Future.successful(endpointRef) + } else { + Future.failed(new RpcEndpointNotFoundException(uri)) + } + }(ThreadUtils.sameThread) + } + + override def stop(endpointRef: RpcEndpointRef): Unit = { + require(endpointRef.isInstanceOf[NettyRpcEndpointRef]) + dispatcher.stop(endpointRef) + } + + private[netty] def send(message: RequestMessage): Unit = { + val remoteAddr = message.receiver.address + if (remoteAddr == address) { + val promise = Promise[Any]() + dispatcher.postMessage(message, promise) + promise.future.onComplete { + case Success(response) => + val ack = response.asInstanceOf[Ack] + logDebug(s"Receive ack from ${ack.sender}") + case Failure(e) => + logError(s"Exception when sending $message", e) + }(ThreadUtils.sameThread) + } else { + try { + // `createClient` will block if it cannot find a known connection, so we should run it in + // clientConnectionExecutor + clientConnectionExecutor.execute(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) + client.sendRpc(serialize(message), new RpcResponseCallback { + + override def onFailure(e: Throwable): Unit = { + logError(s"Exception when sending $message", e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + val ack = deserialize[Ack](response) + logDebug(s"Receive ack from ${ack.sender}") + } + }) + } + }) + } catch { + case e: RejectedExecutionException => { + // `send` after shutting clientConnectionExecutor down, ignore it + logWarning(s"Cannot send ${message} because RpcEnv is stopped") + } + } + } + } + + private[netty] def ask(message: RequestMessage): Future[Any] = { + val promise = Promise[Any]() + val remoteAddr = message.receiver.address + if (remoteAddr == address) { + val p = Promise[Any]() + dispatcher.postMessage(message, p) + p.future.onComplete { + case Success(response) => + val reply = response.asInstanceOf[AskResponse] + if (reply.reply.isInstanceOf[RpcFailure]) { + if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { + logWarning(s"Ignore failure: ${reply.reply}") + } + } else if (!promise.trySuccess(reply.reply)) { + logWarning(s"Ignore message: ${reply}") + } + case Failure(e) => + if 
(!promise.tryFailure(e)) { + logWarning("Ignore Exception", e) + } + }(ThreadUtils.sameThread) + } else { + try { + // `createClient` will block if it cannot find a known connection, so we should run it in + // clientConnectionExecutor + clientConnectionExecutor.execute(new Runnable { + override def run(): Unit = { + val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) + client.sendRpc(serialize(message), new RpcResponseCallback { + + override def onFailure(e: Throwable): Unit = { + if (!promise.tryFailure(e)) { + logWarning("Ignore Exception", e) + } + } + + override def onSuccess(response: Array[Byte]): Unit = { + val reply = deserialize[AskResponse](response) + if (reply.reply.isInstanceOf[RpcFailure]) { + if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { + logWarning(s"Ignore failure: ${reply.reply}") + } + } else if (!promise.trySuccess(reply.reply)) { + logWarning(s"Ignore message: ${reply}") + } + } + }) + } + }) + } catch { + case e: RejectedExecutionException => { + if (!promise.tryFailure(e)) { + logWarning(s"Ignore failure", e) + } + } + } + } + promise.future + } + + private[netty] def serialize(content: Any): Array[Byte] = { + val buffer = javaSerializerInstance.serialize(content) + Arrays.copyOfRange( + buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) + } + + private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = { + deserialize { () => + javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + } + } + + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { + dispatcher.getRpcEndpointRef(endpoint) + } + + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = + new NettyRpcAddress(address.host, address.port, endpointName).toString + + override def shutdown(): Unit = { + cleanup() + } + + override def awaitTermination(): Unit = { + dispatcher.awaitTermination() + } + + private def cleanup(): Unit = { + if (timeoutScheduler != null) { + timeoutScheduler.shutdownNow() + } + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } + if (dispatcher != null) { + dispatcher.stop() + } + if (clientConnectionExecutor != null) { + clientConnectionExecutor.shutdownNow() + } + } + + override def deserialize[T](deserializationAction: () => T): T = { + NettyRpcEnv.currentEnv.withValue(this) { + deserializationAction() + } + } +} + +private[netty] object NettyRpcEnv extends Logging { + + /** + * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. + * Use `currentEnv` to wrap the deserialization codes. E.g., + * + * {{{ + * NettyRpcEnv.currentEnv.withValue(this) { + * your deserialization codes + * } + * }}} + */ + private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null) +} + +private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { + + def create(config: RpcEnvConfig): RpcEnv = { + val sparkConf = config.conf + // Use JavaSerializerInstance in multiple threads is safe. 
However, if we plan to support + // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance + val javaSerializerInstance = + new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] + val nettyEnv = + new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) + val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => + nettyEnv.start(actualPort) + (nettyEnv, actualPort) + } + try { + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + } catch { + case NonFatal(e) => + nettyEnv.shutdown() + throw e + } + } +} + +private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) + extends RpcEndpointRef(conf) with Serializable with Logging { + + @transient @volatile private var nettyEnv: NettyRpcEnv = _ + + @transient @volatile private var _address: NettyRpcAddress = _ + + def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) { + this(conf) + this._address = _address + this.nettyEnv = nettyEnv + } + + override def address: RpcAddress = _address.toRpcAddress + + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject() + _address = in.readObject().asInstanceOf[NettyRpcAddress] + nettyEnv = NettyRpcEnv.currentEnv.value + } + + private def writeObject(out: ObjectOutputStream): Unit = { + out.defaultWriteObject() + out.writeObject(_address) + } + + override def name: String = _address.name + + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + val promise = Promise[Any]() + val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration)) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true)) + f.onComplete { v => + timeoutCancelable.cancel(true) + if (!promise.tryComplete(v)) { + logWarning(s"Ignore message $v") + } + }(ThreadUtils.sameThread) + promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) + } + + override def send(message: Any): Unit = { + require(message != null, "Message is null") + nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false)) + } + + override def toString: String = s"NettyRpcEndpointRef(${_address})" + + def toURI: URI = new URI(s"spark://${_address}") + + final override def equals(that: Any): Boolean = that match { + case other: NettyRpcEndpointRef => _address == other._address + case _ => false + } + + final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode() +} + +/** + * The message that is sent from the sender to the receiver. + */ +private[netty] case class RequestMessage( + senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean) + +/** + * The base trait for all messages that are sent back from the receiver to the sender. + */ +private[netty] trait ResponseMessage + +/** + * The reply for `ask` from the receiver side. + */ +private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any) + extends ResponseMessage + +/** + * A message to send back to the receiver side. It's necessary because [[TransportClient]] only + * clean the resources when it receives a reply. + */ +private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage + +/** + * A response that indicates some failure happens in the receiver side. 
+ */ +private[netty] case class RpcFailure(e: Throwable) + +/** + * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast + * network events and forward messages to [[Dispatcher]]. + */ +private[netty] class NettyRpcHandler( + dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + + private type ClientAddress = RpcAddress + private type RemoteEnvAddress = RpcAddress + + // Store all client addresses and their NettyRpcEnv addresses. + @GuardedBy("this") + private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() + + // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection + // count because `TransportClientFactory.createClient` will create multiple connections + // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection + // to send the message. See `TransportClientFactory.createClient` for more details. + @GuardedBy("this") + private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]() + + override def receive( + client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { + val requestMessage = nettyEnv.deserialize[RequestMessage](message) + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + assert(addr != null) + val remoteEnvAddress = requestMessage.senderAddress + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val broadcastMessage = + synchronized { + // If the first connection to a remote RpcEnv is found, we should broadcast "Associated" + if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { + // clientAddr connects at the first time + val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) + // Increase the connection number of remoteEnvAddress + remoteConnectionCount.put(remoteEnvAddress, count + 1) + if (count == 0) { + // This is the first connection, so fire "Associated" + Some(Associated(remoteEnvAddress)) + } else { + None + } + } else { + None + } + } + broadcastMessage.foreach(dispatcher.broadcastMessage) + dispatcher.postMessage(requestMessage, callback) + } + + override def getStreamManager: StreamManager = new OneForOneStreamManager + + override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + if (addr != null) { + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val broadcastMessage = + synchronized { + remoteAddresses.get(clientAddr).map(AssociationError(cause, _)) + } + if (broadcastMessage.isEmpty) { + logError(cause.getMessage, cause) + } else { + dispatcher.broadcastMessage(broadcastMessage.get) + } + } else { + // If the channel is closed before connecting, its remoteAddress will be null. 
+ // See java.net.Socket.getRemoteSocketAddress + // Because we cannot get a RpcAddress, just log it + logError("Exception before connecting to the client", cause) + } + } + + override def connectionTerminated(client: TransportClient): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + if (addr != null) { + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val broadcastMessage = + synchronized { + // If the last connection to a remote RpcEnv is terminated, we should broadcast + // "Disassociated" + remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => + remoteAddresses -= clientAddr + val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) + assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent") + if (count - 1 == 0) { + // We lost all clients, so clean up and fire "Disassociated" + remoteConnectionCount.remove(remoteEnvAddress) + Some(Disassociated(remoteEnvAddress)) + } else { + // Decrease the connection number of remoteEnvAddress + remoteConnectionCount.put(remoteEnvAddress, count - 1) + None + } + } + } + broadcastMessage.foreach(dispatcher.broadcastMessage) + } else { + // If the channel is closed before connecting, its remoteAddress will be null. In this case, + // we can ignore it since we don't fire "Associated". + // See java.net.Socket.getRemoteSocketAddress + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 7478ab0fc2..e749631bf6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ @@ -33,7 +33,7 @@ class BlockManagerSlaveEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends RpcEndpoint with Logging { + extends ThreadSafeRpcEndpoint with Logging { private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") @@ -80,7 +80,7 @@ class BlockManagerSlaveEndpoint( future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) context.reply(response) - logDebug("Sent response: " + response + " to " + context.sender) + logDebug("Sent response: " + response + " to " + context.senderAddress) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 22e291a2b4..1ed098379e 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -85,7 +85,11 @@ private[spark] object ThreadUtils { */ def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() - Executors.newSingleThreadScheduledExecutor(threadFactory) + val executor = 
new ScheduledThreadPoolExecutor(1, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor } /** diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index af4e68950f..7e70308bb3 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -168,10 +168,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - val sender = mock(classOf[RpcEndpointRef]) - when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.sender).thenReturn(sender) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) verify(rpcCallContext).reply(any()) verify(rpcCallContext, never()).sendFailure(any()) @@ -198,10 +197,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - val sender = mock(classOf[RpcEndpointRef]) - when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.sender).thenReturn(sender) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) verify(rpcCallContext, never()).reply(any()) verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 33270bec62..2d14249855 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -41,6 +41,7 @@ object SSLSampleConfigs { def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) + conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -54,6 +55,7 @@ object SSLSampleConfigs { def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) + conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) conf.set("spark.ssl.keyStorePassword", "password") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index e9034e39a7..40c24bdecc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -26,8 +26,7 @@ class WorkerWatcherSuite extends SparkFunSuite { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") - val 
workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) - workerWatcher.setTesting(testing = true) + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) @@ -39,8 +38,7 @@ class WorkerWatcherSuite extends SparkFunSuite { val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val otherRpcAddress = RpcAddress("4.3.2.1", 1234) - val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) - workerWatcher.setTesting(testing = true) + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(otherRpcAddress) assert(!workerWatcher.isShutDown) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6ceafe4337..3bead6395d 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.rpc +import java.io.NotSerializableException import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -99,7 +100,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) - val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") val reply = newRpcEndpointRef.askWithRetry[String]("Echo") assert("Echo" === reply) @@ -328,9 +328,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def onStop(): Unit = { selfOption = Option(self) } - - override def onError(cause: Throwable): Unit = { - } }) env.stop(endpointRef) @@ -516,6 +513,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(events === List( ("onConnected", remoteAddress), ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress)) || + events === List( + ("onConnected", remoteAddress), ("onDisconnected", remoteAddress))) } } @@ -535,15 +535,84 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { "local", env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") - intercept[TimeoutException] { + val e = intercept[Exception] { Await.result(f, 1 seconds) } + assert(e.isInstanceOf[TimeoutException] || // For Akka + e.isInstanceOf[NotSerializableException] // For Netty + ) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() } } + test("port conflict") { + val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) + assert(anotherEnv.address.port != env.address.port) + } + + test("send with authentication") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val localEnv = createRpcEnv(conf, "authentication-local", 13345) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) + + try { + @volatile var message: String = null + localEnv.setupEndpoint("send-authentication", new RpcEndpoint { + override val rpcEnv = localEnv + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => message = msg + } + }) + val rpcEndpointRef = + 
remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication") + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } finally { + localEnv.shutdown() + localEnv.awaitTermination() + remoteEnv.shutdown() + remoteEnv.awaitTermination() + } + } + + test("ask with authentication") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val localEnv = createRpcEnv(conf, "authentication-local", 13345) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) + + try { + localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { + override val rpcEnv = localEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + context.reply(msg) + } + } + }) + val rpcEndpointRef = + remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication") + val reply = rpcEndpointRef.askWithRetry[String]("hello") + assert("hello" === reply) + } finally { + localEnv.shutdown() + localEnv.awaitTermination() + remoteEnv.shutdown() + remoteEnv.awaitTermination() + } + } + test("construct RpcTimeout with conf property") { val conf = new SparkConf @@ -612,7 +681,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // once the future is complete to verify addMessageIfTimeout was invoked val reply3 = intercept[RpcTimeoutException] { - Await.result(fut3, 200 millis) + Await.result(fut3, 2000 millis) }.getMessage // When the future timed out, the recover callback should have used diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala new file mode 100644 index 0000000000..5e8da3e205 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc + +import scala.collection.mutable.ArrayBuffer + +import org.scalactic.TripleEquals + +class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals { + + override val rpcEnv: RpcEnv = null + + @volatile private var receiveMessages = ArrayBuffer[Any]() + + @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]() + + @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]() + + @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]() + + @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable, RpcAddress)]() + + @volatile private var started = false + + @volatile private var stopped = false + + override def receive: PartialFunction[Any, Unit] = { + case message: Any => receiveMessages += message + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case message: Any => receiveAndReplyMessages += message + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + onConnectedMessages += remoteAddress + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + onNetworkErrorMessages += cause -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + onDisconnectedMessages += remoteAddress + } + + def numReceiveMessages: Int = receiveMessages.size + + override def onStart(): Unit = { + started = true + } + + override def onStop(): Unit = { + stopped = true + } + + def verifyStarted(): Unit = { + assert(started, "RpcEndpoint is not started") + } + + def verifyStopped(): Unit = { + assert(stopped, "RpcEndpoint is not stopped") + } + + def verifyReceiveMessages(expected: Seq[Any]): Unit = { + assert(receiveMessages === expected) + } + + def verifySingleReceiveMessage(message: Any): Unit = { + verifyReceiveMessages(List(message)) + } + + def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = { + assert(receiveAndReplyMessages === expected) + } + + def verifySingleReceiveAndReplyMessage(message: Any): Unit = { + verifyReceiveAndReplyMessages(List(message)) + } + + def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = { + verifyOnConnectedMessages(List(remoteAddress)) + } + + def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = { + assert(onConnectedMessages === expected) + } + + def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = { + verifyOnDisconnectedMessages(List(remoteAddress)) + } + + def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = { + assert(onDisconnectedMessages === expected) + } + + def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress: RpcAddress): Unit = { + verifyOnNetworkErrorMessages(List(cause -> remoteAddress)) + } + + def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]): Unit = { + assert(onNetworkErrorMessages === expected) + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala new file mode 100644 index 0000000000..120cf1b6fa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint} + +class InboxSuite extends SparkFunSuite { + + test("post") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + when(endpointRef.name).thenReturn("hello") + + val dispatcher = mock(classOf[Dispatcher]) + + val inbox = new Inbox(endpointRef, endpoint) + val message = ContentMessage(null, "hi", false, null) + inbox.post(message) + inbox.process(dispatcher) + assert(inbox.isEmpty) + + endpoint.verifySingleReceiveMessage("hi") + + inbox.stop() + inbox.process(dispatcher) + assert(inbox.isEmpty) + endpoint.verifyStarted() + endpoint.verifyStopped() + } + + test("post: with reply") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val inbox = new Inbox(endpointRef, endpoint) + val message = ContentMessage(null, "hi", true, null) + inbox.post(message) + inbox.process(dispatcher) + assert(inbox.isEmpty) + + endpoint.verifySingleReceiveAndReplyMessage("hi") + } + + test("post: multiple threads") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + when(endpointRef.name).thenReturn("hello") + + val dispatcher = mock(classOf[Dispatcher]) + + val numDroppedMessages = new AtomicInteger(0) + val inbox = new Inbox(endpointRef, endpoint) { + override def onDrop(message: InboxMessage): Unit = { + numDroppedMessages.incrementAndGet() + } + } + + val exitLatch = new CountDownLatch(10) + + for (_ <- 0 until 10) { + new Thread { + override def run(): Unit = { + for (_ <- 0 until 100) { + val message = ContentMessage(null, "hi", false, null) + inbox.post(message) + } + exitLatch.countDown() + } + }.start() + } + // Try to process some messages + inbox.process(dispatcher) + inbox.stop() + // After `stop` is called, further messages will be dropped. However, while `stop` is called, + // some messages may be post to Inbox, so process them here. 
+ inbox.process(dispatcher) + assert(inbox.isEmpty) + + exitLatch.await(30, TimeUnit.SECONDS) + + assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get) + endpoint.verifyStarted() + endpoint.verifyStopped() + } + + test("post: Associated") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val remoteAddress = RpcAddress("localhost", 11111) + + val inbox = new Inbox(endpointRef, endpoint) + inbox.post(Associated(remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnConnectedMessage(remoteAddress) + } + + test("post: Disassociated") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val remoteAddress = RpcAddress("localhost", 11111) + + val inbox = new Inbox(endpointRef, endpoint) + inbox.post(Disassociated(remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnDisconnectedMessage(remoteAddress) + } + + test("post: AssociationError") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val remoteAddress = RpcAddress("localhost", 11111) + val cause = new RuntimeException("Oops") + + val inbox = new Inbox(endpointRef, endpoint) + inbox.post(AssociationError(cause, remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress) + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala new file mode 100644 index 0000000000..a5d43d3704 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.SparkFunSuite + +class NettyRpcAddressSuite extends SparkFunSuite { + + test("toString") { + val addr = NettyRpcAddress("localhost", 12345, "test") + assert(addr.toString === "spark://test@localhost:12345") + } + +} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala new file mode 100644 index 0000000000..be19668e17 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.rpc._ + +class NettyRpcEnvSuite extends RpcEnvSuite { + + override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf)) + new NettyRpcEnvFactory().create(config) + } + + test("non-existent endpoint") { + val uri = env.uriOf("test", env.address, "nonexist-endpoint") + val e = intercept[RpcEndpointNotFoundException] { + env.setupEndpointRef("test", env.address, "nonexist-endpoint") + } + assert(e.getMessage.contains(uri)) + } + +} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala new file mode 100644 index 0000000000..06ca035d19 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.net.InetSocketAddress + +import io.netty.channel.Channel +import org.mockito.Mockito._ +import org.mockito.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.rpc._ + +class NettyRpcHandlerSuite extends SparkFunSuite { + + val env = mock(classOf[NettyRpcEnv]) + when(env.deserialize(any(classOf[Array[Byte]]))(any())). 
+ thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + + test("receive") { + val dispatcher = mock(classOf[Dispatcher]) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + + val channel = mock(classOf[Channel]) + val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.receive(client, null, null) + + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001)) + nettyRpcHandler.receive(client, null, null) + + verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345))) + } + + test("connectionTerminated") { + val dispatcher = mock(classOf[Dispatcher]) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + + val channel = mock(classOf[Channel]) + val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.receive(client, null, null) + + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.connectionTerminated(client) + + verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345))) + verify(dispatcher, times(1)).broadcastMessage(Disassociated(RpcAddress("localhost", 12345))) + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index df841288a0..fbb8bb6b2f 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -78,6 +78,10 @@ public class TransportClient implements Closeable { this.handler = Preconditions.checkNotNull(handler); } + public Channel getChannel() { + return channel; + } + public boolean isActive() { return channel.isOpen() || channel.isActive(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 2ba92a40f8..dbb7f95f55 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -52,4 +52,6 @@ public abstract class RpcHandler { * No further requests will come from this client. 
*/ public void connectionTerminated(TransportClient client) { } + + public void exceptionCaught(Throwable cause, TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index df60278058..96941d26be 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -71,6 +71,7 @@ public class TransportRequestHandler extends MessageHandler { @Override public void exceptionCaught(Throwable cause) { + rpcHandler.exceptionCaught(cause, reverseClient); } @Override diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 204e6142fd..d053e9e849 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -474,7 +474,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Remote messages case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => val successful = - registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) + registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.senderAddress) context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 07a0a45e60..0df31736c1 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -556,10 +556,7 @@ private[spark] class ApplicationMaster( override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) extends RpcEndpoint with Logging { - override def onStart(): Unit = { - driver.send(RegisterClusterManager(self)) - - } + driver.send(RegisterClusterManager(self)) override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => -- cgit v1.2.3
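The connection bookkeeping in NettyRpcHandler.connectionTerminated above amounts to reference-counting live connections per remote RpcEnv and broadcasting Disassociated only once the last connection from that RpcEnv closes. A minimal, self-contained sketch of that scheme, using illustrative stand-in names (ConnectionTracker, Addr) rather than the patch's actual types:

import scala.collection.mutable

// Illustrative stand-ins for the patch's RpcAddress and Disassociated types.
case class Addr(host: String, port: Int)
case class Disassociated(remoteAddress: Addr)

// Sketch: map each client connection to its remote RpcEnv address, count
// connections per RpcEnv, and emit Disassociated only when the count drops to 0.
class ConnectionTracker {
  private val remoteAddresses = mutable.Map[Addr, Addr]()       // client address -> remote RpcEnv address
  private val remoteConnectionCount = mutable.Map[Addr, Int]()  // remote RpcEnv address -> live connections

  def onAssociated(clientAddr: Addr, remoteEnvAddress: Addr): Unit = synchronized {
    remoteAddresses(clientAddr) = remoteEnvAddress
    remoteConnectionCount(remoteEnvAddress) = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) + 1
  }

  // Returns Some(Disassociated) only when the last connection to the remote RpcEnv is gone.
  def onTerminated(clientAddr: Addr): Option[Disassociated] = synchronized {
    remoteAddresses.remove(clientAddr).flatMap { remoteEnvAddress =>
      val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
      assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent")
      if (count == 1) {
        remoteConnectionCount.remove(remoteEnvAddress)
        Some(Disassociated(remoteEnvAddress))
      } else {
        remoteConnectionCount(remoteEnvAddress) = count - 1
        None
      }
    }
  }
}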
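The ThreadUtils change above replaces Executors.newSingleThreadScheduledExecutor with a ScheduledThreadPoolExecutor precisely so that setRemoveOnCancelPolicy(true) can be enabled; without it, a cancelled task stays in the work queue until its delay elapses. A small standalone sketch of the behaviour (RemoveOnCancelDemo is an illustrative name, not part of the patch):

import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}

object RemoveOnCancelDemo {
  def main(args: Array[String]): Unit = {
    val executor = new ScheduledThreadPoolExecutor(1)
    // Remove cancelled tasks from the queue immediately instead of waiting
    // for their (possibly very long) delay to elapse.
    executor.setRemoveOnCancelPolicy(true)

    val future = executor.schedule(new Runnable {
      override def run(): Unit = println("timeout fired")
    }, 1, TimeUnit.HOURS)

    future.cancel(true)
    // With the policy enabled, the cancelled long-delay task no longer sits in
    // the queue, so early-cancelled timeouts do not accumulate in memory.
    println(s"queued tasks after cancel: ${executor.getQueue.size()}")  // expected: 0

    executor.shutdown()
  }
}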
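The InboxSuite "post: multiple threads" test above leans on the Inbox contract that messages posted after stop() are handed to onDrop, so that posted == processed + dropped, and that process() may need to run once more after stop() to drain messages that arrived while stop() was in flight. A toy sketch of that contract, with an illustrative name (TinyInbox) rather than the real Inbox:

import scala.collection.mutable

class TinyInbox[T](onDrop: T => Unit) {
  private val messages = mutable.Queue[T]()
  private var stopped = false

  // Posts after stop() are counted as dropped rather than silently lost.
  def post(message: T): Unit = synchronized {
    if (stopped) onDrop(message) else messages.enqueue(message)
  }

  def stop(): Unit = synchronized { stopped = true }

  // Drain whatever is currently queued; callers may call this again after
  // stop(), because messages can still be enqueued while stop() is running.
  def process(handler: T => Unit): Unit = {
    val pending = synchronized {
      val drained = messages.toList
      messages.clear()
      drained
    }
    pending.foreach(handler)
  }

  def isEmpty: Boolean = synchronized(messages.isEmpty)
}

With this contract, counting handled messages in `handler` and dropped messages in `onDrop` reproduces the test's invariant that the totals sum to the number of posts.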