diff options
Diffstat (limited to 'core')
6 files changed, 178 insertions, 11 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a1c66ef4fc..6f336a7c29 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2658,7 +2658,7 @@ object SparkContext extends Logging { val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, url) + new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) } else { new MesosSchedulerBackend(scheduler, sc, url) } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 4089c3e771..20a9faa178 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -27,6 +27,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslServerBootstrap import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.network.util.TransportConf import org.apache.spark.util.Utils /** @@ -45,11 +46,16 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val useSasl: Boolean = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) - private val blockHandler = new ExternalShuffleBlockHandler(transportConf) + private val blockHandler = newShuffleBlockHandler(transportConf) private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) private var server: TransportServer = _ + /** Create a new shuffle block handler. Factored out for subclasses to override. */ + protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = { + new ExternalShuffleBlockHandler(conf) + } + /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { @@ -93,6 +99,13 @@ object ExternalShuffleService extends Logging { private val barrier = new CountDownLatch(1) def main(args: Array[String]): Unit = { + main(args, (conf: SparkConf, sm: SecurityManager) => new ExternalShuffleService(conf, sm)) + } + + /** A helper main method that allows the caller to call this with a custom shuffle service. */ + private[spark] def main( + args: Array[String], + newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = { val sparkConf = new SparkConf Utils.loadDefaultSparkProperties(sparkConf) val securityManager = new SecurityManager(sparkConf) @@ -100,7 +113,7 @@ object ExternalShuffleService extends Logging { // we override this value since this service is started from the command line // and we assume the user really wants it to be running sparkConf.set("spark.shuffle.service.enabled", "true") - server = new ExternalShuffleService(sparkConf, securityManager) + server = newShuffleService(sparkConf, securityManager) server.start() installShutdownHook() diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala new file mode 100644 index 0000000000..061857476a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.net.SocketAddress + +import scala.collection.mutable + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.deploy.ExternalShuffleService +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver +import org.apache.spark.network.util.TransportConf + +/** + * An RPC endpoint that receives registration requests from Spark drivers running on Mesos. + * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. + */ +private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) + extends ExternalShuffleBlockHandler(transportConf) with Logging { + + // Stores a map of driver socket addresses to app ids + private val connectedApps = new mutable.HashMap[SocketAddress, String] + + protected override def handleMessage( + message: BlockTransferMessage, + client: TransportClient, + callback: RpcResponseCallback): Unit = { + message match { + case RegisterDriverParam(appId) => + val address = client.getSocketAddress + logDebug(s"Received registration request from app $appId (remote address $address).") + if (connectedApps.contains(address)) { + val existingAppId = connectedApps(address) + if (!existingAppId.equals(appId)) { + logError(s"A new app '$appId' has connected to existing address $address, " + + s"removing previously registered app '$existingAppId'.") + applicationRemoved(existingAppId, true) + } + } + connectedApps(address) = appId + callback.onSuccess(new Array[Byte](0)) + case _ => super.handleMessage(message, client, callback) + } + } + + /** + * On connection termination, clean up shuffle files written by the associated application. + */ + override def connectionTerminated(client: TransportClient): Unit = { + val address = client.getSocketAddress + if (connectedApps.contains(address)) { + val appId = connectedApps(address) + logInfo(s"Application $appId disconnected (address was $address).") + applicationRemoved(appId, true /* cleanupLocalDirs */) + connectedApps.remove(address) + } else { + logWarning(s"Unknown $address disconnected.") + } + } + + /** An extractor object for matching [[RegisterDriver]] message. */ + private object RegisterDriverParam { + def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId) + } +} + +/** + * A wrapper of [[ExternalShuffleService]] that provides an additional endpoint for drivers + * to associate with. This allows the shuffle service to detect when a driver is terminated + * and can clean up the associated shuffle files. + */ +private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManager: SecurityManager) + extends ExternalShuffleService(conf, securityManager) { + + protected override def newShuffleBlockHandler( + conf: TransportConf): ExternalShuffleBlockHandler = { + new MesosExternalShuffleBlockHandler(conf) + } +} + +private[spark] object MesosExternalShuffleService extends Logging { + + def main(args: Array[String]): Unit = { + ExternalShuffleService.main(args, + (conf: SparkConf, sm: SecurityManager) => new MesosExternalShuffleService(conf, sm)) + } +} + + diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index d2b2baef1d..dfcbc51cdf 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -47,11 +47,11 @@ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint * * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. * - * The lift-cycle will be: + * The life-cycle of an endpoint is: * - * constructor onStart receive* onStop + * constructor -> onStart -> receive* -> onStop * - * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use * [[ThreadSafeRpcEndpoint]] * * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b7fde0d9b3..15a0915708 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -26,12 +26,15 @@ import scala.collection.mutable.{HashMap, HashSet} import com.google.common.collect.HashBiMap import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} + +import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -46,7 +49,8 @@ import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, - master: String) + master: String, + securityManager: SecurityManager) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with MesosSchedulerUtils { @@ -56,12 +60,19 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + // If shuffle service is enabled, the Spark driver will register with the shuffle service. + // This is for cleaning up shuffle files reliably. + private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] var totalCoresAcquired = 0 val slaveIdsWithExecutors = new HashSet[String] + // Maping from slave Id to hostname + private val slaveIdToHost = new HashMap[String, String] + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] // How many times tasks on each slave failed val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] @@ -90,6 +101,19 @@ private[spark] class CoarseMesosSchedulerBackend( private val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // A client for talking to the external shuffle service, if it is a + private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { + if (shuffleServiceEnabled) { + Some(new MesosExternalShuffleClient( + SparkTransportConf.fromSparkConf(conf), + securityManager, + securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled())) + } else { + None + } + } + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -188,6 +212,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue + mesosExternalShuffleClient.foreach(_.init(appId)) logInfo("Registered as framework ID " + appId) markRegistered() } @@ -244,6 +269,7 @@ private[spark] class CoarseMesosSchedulerBackend( // accept the offer and launch the task logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname d.launchTasks( Collections.singleton(offer.getId), Collections.singleton(taskBuilder.build()), filters) @@ -261,7 +287,27 @@ private[spark] class CoarseMesosSchedulerBackend( val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo(s"Mesos task $taskId is now $state") + val slaveId: String = status.getSlaveId.getValue stateLock.synchronized { + // If the shuffle service is enabled, have the driver register with each one of the + // shuffle services. This allows the shuffle services to clean up state associated with + // this application when the driver exits. There is currently not a great way to detect + // this through Mesos, since the shuffle services are set up independently. + if (TaskState.fromMesos(state).equals(TaskState.RUNNING) && + slaveIdToHost.contains(slaveId) && + shuffleServiceEnabled) { + assume(mesosExternalShuffleClient.isDefined, + "External shuffle client was not instantiated even though shuffle service is enabled.") + // TODO: Remove this and allow the MesosExternalShuffleService to detect + // framework termination when new Mesos Framework HTTP API is available. + val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) + val hostname = slaveIdToHost.remove(slaveId).get + logDebug(s"Connecting to shuffle service on slave $slaveId, " + + s"host $hostname, port $externalShufflePort for app ${conf.getAppId}") + mesosExternalShuffleClient.get + .registerDriverWithShuffleService(hostname, externalShufflePort) + } + if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 4b504df7b8..525ee0d3bd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SecurityManager, SparkFunSuite} class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext @@ -59,7 +59,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, driver: SchedulerDriver): CoarseMesosSchedulerBackend = { - val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") { + val securityManager = mock[SecurityManager] + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( masterUrl: String, scheduler: Scheduler, |