diff options
author | Mridul Muralidharan <mridul@gmail.com> | 2013-04-15 18:12:11 +0530 |
---|---|---|
committer | Mridul Muralidharan <mridul@gmail.com> | 2013-04-15 18:12:11 +0530 |
commit | d90d2af1036e909f81cf77c85bfe589993c4f9f3 (patch) | |
tree | 4119f2702709fe62317e15273b49183678fe24e0 /core/src | |
parent | 6798a09df84fb97e196c84d55cf3e21ad676871f (diff) | |
download | spark-d90d2af1036e909f81cf77c85bfe589993c4f9f3.tar.gz spark-d90d2af1036e909f81cf77c85bfe589993c4f9f3.tar.bz2 spark-d90d2af1036e909f81cf77c85bfe589993c4f9f3.zip |
Checkpoint commit - compiles and passes a lot of tests - not all though, looking into FileSuite issues
Diffstat (limited to 'core/src')
66 files changed, 3488 insertions, 477 deletions
diff --git a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..d4badbc5c4 --- /dev/null +++ b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,18 @@ +package spark.deploy + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + def getUserNameFromEnvironment(): String = { + // defaulting to -D ... + System.getProperty("user.name") + } + + def runAsUser(func: (Product) => Unit, args: Product) { + + // Add support, if exists - for now, simply run func ! + func(args) + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..66e5ad8491 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,59 @@ +package spark.deploy + +import collection.mutable.HashMap +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import java.security.PrivilegedExceptionAction + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + val yarnConf = new YarnConfiguration(new Configuration()) + + def getUserNameFromEnvironment(): String = { + // defaulting to env if -D is not present ... + val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name)) + + // If nothing found, default to user we are running as + if (retval == null) System.getProperty("user.name") else retval + } + + def runAsUser(func: (Product) => Unit, args: Product) { + runAsUser(func, args, getUserNameFromEnvironment()) + } + + def runAsUser(func: (Product) => Unit, args: Product, user: String) { + + // println("running as user " + jobUserName) + + UserGroupInformation.setConfiguration(yarnConf) + val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user) + appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + func(args) + // no return value ... + null + } + }) + } + + // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true. + def isYarnMode(): Boolean = { + val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")) + java.lang.Boolean.valueOf(yarnMode) + } + + // Set an env variable indicating we are running in YARN mode. + // Note that anything with SPARK prefix gets propagated to all (remote) processes + def setYarnMode() { + System.setProperty("SPARK_YARN_MODE", "true") + } + + def setYarnMode(env: HashMap[String, String]) { + env("SPARK_YARN_MODE") = "true" + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala new file mode 100644 index 0000000000..65361e0ed9 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala @@ -0,0 +1,342 @@ +package spark.deploy.yarn + +import java.net.Socket +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import scala.collection.JavaConversions._ +import spark.{SparkContext, Logging, Utils} +import org.apache.hadoop.security.UserGroupInformation +import java.security.PrivilegedExceptionAction + +class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging { + + def this(args: ApplicationMasterArguments) = this(args, new Configuration()) + + private var rpc: YarnRPC = YarnRPC.create(conf) + private var resourceManager: AMRMProtocol = null + private var appAttemptId: ApplicationAttemptId = null + private var userThread: Thread = null + private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + private var yarnAllocator: YarnAllocationHandler = null + + def run() { + + // Initialization + val jobUserName = Utils.getUserNameFromEnvironment() + logInfo("running as user " + jobUserName) + + // run as user ... + UserGroupInformation.setConfiguration(yarnConf) + val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName) + appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + runImpl() + return null + } + }) + } + + private def runImpl() { + + appAttemptId = getApplicationAttemptId() + resourceManager = registerWithResourceManager() + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + + // Workaround until hadoop moves to something which has + // https://issues.apache.org/jira/browse/HADOOP-8406 + // ignore result + // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times + // Hence args.workerCores = numCore disabled above. Any better option ? + // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + + ApplicationMaster.register(this) + // Start the user's JAR + userThread = startUserClass() + + // This a bit hacky, but we need to wait until the spark.master.port property has + // been set by the Thread executing the user class. + waitForSparkMaster() + + // Allocate all containers + allocateWorkers() + + // Wait for the user class to Finish + userThread.join() + + // Finish the ApplicationMaster + finishApplicationMaster() + // TODO: Exit based on success/failure + System.exit(0) + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + return appAttemptId + } + + private def registerWithResourceManager(): AMRMProtocol = { + val rmAddress = NetUtils.createSocketAddr(yarnConf.get( + YarnConfiguration.RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) + logInfo("Connecting to ResourceManager at " + rmAddress) + return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) + .asInstanceOf[RegisterApplicationMasterRequest] + appMasterRequest.setApplicationAttemptId(appAttemptId) + // Setting this to master host,port - so that the ApplicationReport at client has some sensible info. + // Users can then monitor stderr/stdout on that node if required. + appMasterRequest.setHost(Utils.localHostName()) + appMasterRequest.setRpcPort(0) + // What do we provide here ? Might make sense to expose something sensible later ? + appMasterRequest.setTrackingUrl("") + return resourceManager.registerApplicationMaster(appMasterRequest) + } + + private def waitForSparkMaster() { + logInfo("Waiting for spark master to be reachable.") + var masterUp = false + while(!masterUp) { + val masterHost = System.getProperty("spark.master.host") + val masterPort = System.getProperty("spark.master.port") + try { + val socket = new Socket(masterHost, masterPort.toInt) + socket.close() + logInfo("Master now available: " + masterHost + ":" + masterPort) + masterUp = true + } catch { + case e: Exception => + logError("Failed to connect to master at " + masterHost + ":" + masterPort) + Thread.sleep(100) + } + } + } + + private def startUserClass(): Thread = { + logInfo("Starting the user JAR in a separate Thread") + val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader) + .getMethod("main", classOf[Array[String]]) + val t = new Thread { + override def run() { + var mainArgs: Array[String] = null + var startIndex = 0 + + // I am sure there is a better 'scala' way to do this .... but I am just trying to get things to work right now ! + if (args.userArgs.isEmpty || args.userArgs.get(0) != "yarn-standalone") { + // ensure that first param is ALWAYS "yarn-standalone" + mainArgs = new Array[String](args.userArgs.size() + 1) + mainArgs.update(0, "yarn-standalone") + startIndex = 1 + } + else { + mainArgs = new Array[String](args.userArgs.size()) + } + + args.userArgs.copyToArray(mainArgs, startIndex, args.userArgs.size()) + + mainMethod.invoke(null, mainArgs) + } + } + t.start() + return t + } + + private def allocateWorkers() { + logInfo("Waiting for spark context initialization") + + try { + var sparkContext: SparkContext = null + ApplicationMaster.sparkContextRef.synchronized { + var count = 0 + while (ApplicationMaster.sparkContextRef.get() == null) { + logInfo("Waiting for spark context initialization ... " + count) + count = count + 1 + ApplicationMaster.sparkContextRef.wait(10000L) + } + sparkContext = ApplicationMaster.sparkContextRef.get() + assert(sparkContext != null) + this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData) + } + + + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? + // TODO: Handle container failure + while(yarnAllocator.getNumWorkersRunning < args.numWorkers && + // If user thread exists, then quit ! + userThread.isAlive) { + + this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + ApplicationMaster.incrementAllocatorLoop(1) + Thread.sleep(100) + } + } finally { + // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT : + // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks + ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + } + logInfo("All workers have launched.") + + // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout + if (userThread.isAlive){ + // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. + + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + // must be <= timeoutInterval/ 2. + // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. + // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. + val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + launchReporterThread(interval) + } + } + + // TODO: We might want to extend this to allocate more containers in case they die ! + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (userThread.isAlive){ + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning + if (missingWorkerCount > 0) { + logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers") + yarnAllocator.allocateContainers(missingWorkerCount) + } + else sendProgress() + Thread.sleep(sleepTime) + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + return t + } + + private def sendProgress() { + logDebug("Sending progress") + // simulated with an allocate request with no nodes requested ... + yarnAllocator.allocateContainers(0) + } + + /* + def printContainers(containers: List[Container]) = { + for (container <- containers) { + logInfo("Launching shell command on a new container." + + ", containerId=" + container.getId() + + ", containerNode=" + container.getNodeId().getHost() + + ":" + container.getNodeId().getPort() + + ", containerNodeURI=" + container.getNodeHttpAddress() + + ", containerState" + container.getState() + + ", containerResourceMemory" + + container.getResource().getMemory()) + } + } + */ + + def finishApplicationMaster() { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + // TODO: Check if the application has failed or succeeded + finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED) + resourceManager.finishApplicationMaster(finishReq) + } + +} + +object ApplicationMaster { + // number of times to wait for the allocator loop to complete. + // each loop iteration waits for 100ms, so maximum of 3 seconds. + // This is to ensure that we have reasonable number of containers before we start + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be optimal as more + // containers are available. Might need to handle this better. + private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + def incrementAllocatorLoop(by: Int) { + val count = yarnAllocatorLoop.getAndAdd(by) + if (count >= ALLOCATOR_LOOP_WAIT_COUNT){ + yarnAllocatorLoop.synchronized { + // to wake threads off wait ... + yarnAllocatorLoop.notifyAll() + } + } + } + + private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() + + def register(master: ApplicationMaster) { + applicationMasters.add(master) + } + + val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) + val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) + + def sparkContextInitialized(sc: SparkContext): Boolean = { + var modified = false + sparkContextRef.synchronized { + modified = sparkContextRef.compareAndSet(null, sc) + sparkContextRef.notifyAll() + } + + // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit + // Should not really have to do this, but it helps yarn to evict resources earlier. + // not to mention, prevent Client declaring failure even though we exit'ed properly. + if (modified) { + Runtime.getRuntime().addShutdownHook(new Thread with Logging { + // This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run' + logInfo("Adding shutdown hook for context " + sc) + override def run() { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + // best case ... + for (master <- applicationMasters) master.finishApplicationMaster + } + } ) + } + + // Wait for initialization to complete and atleast 'some' nodes can get allocated + yarnAllocatorLoop.synchronized { + while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){ + yarnAllocatorLoop.wait(1000L) + } + } + modified + } + + def main(argStrings: Array[String]) { + val args = new ApplicationMasterArguments(argStrings) + new ApplicationMaster(args).run() + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala new file mode 100644 index 0000000000..dc89125d81 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -0,0 +1,78 @@ +package spark.deploy.yarn + +import spark.util.IntParam +import collection.mutable.ArrayBuffer + +class ApplicationMasterArguments(val args: Array[String]) { + var userJar: String = null + var userClass: String = null + var userArgs: Seq[String] = Seq[String]() + var workerMemory = 1024 + var workerCores = 1 + var numWorkers = 2 + + parseArgs(args.toList) + + private def parseArgs(inputArgs: List[String]): Unit = { + val userArgsBuffer = new ArrayBuffer[String]() + + var args = inputArgs + + while (! args.isEmpty) { + + args match { + case ("--jar") :: value :: tail => + userJar = value + args = tail + + case ("--class") :: value :: tail => + userClass = value + args = tail + + case ("--args") :: value :: tail => + userArgsBuffer += value + args = tail + + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + + case ("--worker-memory") :: IntParam(value) :: tail => + workerMemory = value + args = tail + + case ("--worker-cores") :: IntParam(value) :: tail => + workerCores = value + args = tail + + case Nil => + if (userJar == null || userClass == null) { + printUsageAndExit(1) + } + + case _ => + printUsageAndExit(1, args) + } + } + + userArgs = userArgsBuffer.readOnly + } + + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + if (unknownParam != null) { + System.err.println("Unknown/unsupported param " + unknownParam) + } + System.err.println( + "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" + + "Options:\n" + + " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n") + System.exit(exitCode) + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala new file mode 100644 index 0000000000..7fa6740579 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala @@ -0,0 +1,326 @@ +package spark.deploy.yarn + +import java.net.{InetSocketAddress, URI} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ +import spark.{Logging, Utils} +import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils} +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import spark.deploy.SparkHadoopUtil + +class Client(conf: Configuration, args: ClientArguments) extends Logging { + + def this(args: ClientArguments) = this(new Configuration(), args) + + var applicationsManager: ClientRMProtocol = null + var rpc: YarnRPC = YarnRPC.create(conf) + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run() { + connectToASM() + logClusterResourceDetails() + + val newApp = getNewApplication() + val appId = newApp.getApplicationId() + + verifyClusterResources(newApp) + val appContext = createApplicationSubmissionContext(appId) + val localResources = prepareLocalResources(appId, "spark") + val env = setupLaunchEnv(localResources) + val amContainer = createContainerLaunchContext(newApp, localResources, env) + + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(amContainer) + appContext.setUser(args.amUser) + + submitApp(appContext) + + monitorApplication(appId) + System.exit(0) + } + + + def connectToASM() { + val rmAddress: InetSocketAddress = NetUtils.createSocketAddr( + yarnConf.get(YarnConfiguration.RM_ADDRESS, YarnConfiguration.DEFAULT_RM_ADDRESS) + ) + logInfo("Connecting to ResourceManager at" + rmAddress) + applicationsManager = rpc.getProxy(classOf[ClientRMProtocol], rmAddress, conf) + .asInstanceOf[ClientRMProtocol] + } + + def logClusterResourceDetails() { + val clusterMetrics: YarnClusterMetrics = getYarnClusterMetrics + logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers) + +/* + val clusterNodeReports: List[NodeReport] = getNodeReports + logDebug("Got Cluster node info from ASM") + for (node <- clusterNodeReports) { + logDebug("Got node report from ASM for, nodeId=" + node.getNodeId + ", nodeAddress=" + node.getHttpAddress + + ", nodeRackName=" + node.getRackName + ", nodeNumContainers=" + node.getNumContainers + ", nodeHealthStatus=" + node.getNodeHealthStatus) + } +*/ + + val queueInfo: QueueInfo = getQueueInfo(args.amQueue) + logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity + + ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size + + ", queueChildQueueCount=" + queueInfo.getChildQueues.size) + } + + def getYarnClusterMetrics: YarnClusterMetrics = { + val request: GetClusterMetricsRequest = Records.newRecord(classOf[GetClusterMetricsRequest]) + val response: GetClusterMetricsResponse = applicationsManager.getClusterMetrics(request) + return response.getClusterMetrics + } + + def getNodeReports: List[NodeReport] = { + val request: GetClusterNodesRequest = Records.newRecord(classOf[GetClusterNodesRequest]) + val response: GetClusterNodesResponse = applicationsManager.getClusterNodes(request) + return response.getNodeReports.toList + } + + def getQueueInfo(queueName: String): QueueInfo = { + val request: GetQueueInfoRequest = Records.newRecord(classOf[GetQueueInfoRequest]) + request.setQueueName(queueName) + request.setIncludeApplications(true) + request.setIncludeChildQueues(false) + request.setRecursive(false) + Records.newRecord(classOf[GetQueueInfoRequest]) + return applicationsManager.getQueueInfo(request).getQueueInfo + } + + def getNewApplication(): GetNewApplicationResponse = { + logInfo("Requesting new Application") + val request = Records.newRecord(classOf[GetNewApplicationRequest]) + val response = applicationsManager.getNewApplication(request) + logInfo("Got new ApplicationId: " + response.getApplicationId()) + return response + } + + def verifyClusterResources(app: GetNewApplicationResponse) = { + val maxMem = app.getMaximumResourceCapability().getMemory() + logInfo("Max mem capabililty of resources in this cluster " + maxMem) + + // If the cluster does not have enough memory resources, exit. + val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory + if (requestedMem > maxMem) { + logError("Cluster cannot satisfy memory resource request of " + requestedMem) + System.exit(1) + } + } + + def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = { + logInfo("Setting up application submission context for ASM") + val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) + appContext.setApplicationId(appId) + appContext.setApplicationName("Spark") + return appContext + } + + def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val locaResources = HashMap[String, LocalResource]() + // Upload Spark and the application JAR to the remote file system + // Add them as local resources to the AM + val fs = FileSystem.get(conf) + Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF")) + .foreach { case(destName, _localPath) => + val localPath: String = if (_localPath != null) _localPath.trim() else "" + if (! localPath.isEmpty()) { + val src = new Path(localPath) + val pathSuffix = appName + "/" + appId.getId() + destName + val dst = new Path(fs.getHomeDirectory(), pathSuffix) + logInfo("Uploading " + src + " to " + dst) + fs.copyFromLocalFile(false, true, src, dst) + val destStatus = fs.getFileStatus(dst) + + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(LocalResourceType.FILE) + amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst)) + amJarRsrc.setTimestamp(destStatus.getModificationTime()) + amJarRsrc.setSize(destStatus.getLen()) + locaResources(destName) = amJarRsrc + } + } + return locaResources + } + + def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = { + logInfo("Setting up the launch environment") + val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null) + + val env = new HashMap[String, String]() + Apps.addToEnvironment(env, Environment.USER.name, args.amUser) + + // If log4j present, ensure ours overrides all others + if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Client.populateHadoopClasspath(yarnConf, env) + SparkHadoopUtil.setYarnMode(env) + env("SPARK_YARN_JAR_PATH") = + localResources("spark.jar").getResource().getScheme.toString() + "://" + + localResources("spark.jar").getResource().getFile().toString() + env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString() + env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString() + + env("SPARK_YARN_USERJAR_PATH") = + localResources("app.jar").getResource().getScheme.toString() + "://" + + localResources("app.jar").getResource().getFile().toString() + env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString() + env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString() + + if (log4jConfLocalRes != null) { + env("SPARK_YARN_LOG4J_PATH") = + log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString() + env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString() + env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString() + } + + // Add each SPARK-* key to the environment + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + return env + } + + def userArgsToString(clientArgs: ClientArguments): String = { + val prefix = " --args " + val args = clientArgs.userArgs + val retval = new StringBuilder() + for (arg <- args){ + retval.append(prefix).append(" '").append(arg).append("' ") + } + + retval.toString + } + + def createContainerLaunchContext(newApp: GetNewApplicationResponse, + localResources: HashMap[String, LocalResource], + env: HashMap[String, String]): ContainerLaunchContext = { + logInfo("Setting up container launch context") + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) + amContainer.setLocalResources(localResources) + amContainer.setEnvironment(env) + + val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() + + var amMemory = ((args.amMemory / minResMemory) * minResMemory) + + (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD + + // Extra options for the JVM + var JAVA_OPTS = "" + + // Add Xmx for am memory + JAVA_OPTS += "-Xmx" + amMemory + "m " + + // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same + // node, spark gc effects all other containers performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is + // limited to subset of cores on a node. + if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) { + // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tenant machines + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + + // Command for the ApplicationMaster + val commands = List[String]("java " + + " -server " + + JAVA_OPTS + + " spark.deploy.yarn.ApplicationMaster" + + " --class " + args.userClass + + " --jar " + args.userJar + + userArgsToString(args) + + " --worker-memory " + args.workerMemory + + " --worker-cores " + args.workerCores + + " --num-workers " + args.numWorkers + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Command for the ApplicationMaster: " + commands(0)) + amContainer.setCommands(commands) + + val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] + // Memory for the ApplicationMaster + capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + amContainer.setResource(capability) + + return amContainer + } + + def submitApp(appContext: ApplicationSubmissionContext) = { + // Create the request to send to the applications manager + val appRequest = Records.newRecord(classOf[SubmitApplicationRequest]) + .asInstanceOf[SubmitApplicationRequest] + appRequest.setApplicationSubmissionContext(appContext) + // Submit the application to the applications manager + logInfo("Submitting application to ASM") + applicationsManager.submitApplication(appRequest) + } + + def monitorApplication(appId: ApplicationId): Boolean = { + while(true) { + Thread.sleep(1000) + val reportRequest = Records.newRecord(classOf[GetApplicationReportRequest]) + .asInstanceOf[GetApplicationReportRequest] + reportRequest.setApplicationId(appId) + val reportResponse = applicationsManager.getApplicationReport(reportRequest) + val report = reportResponse.getApplicationReport() + + logInfo("Application report from ASM: \n" + + "\t application identifier: " + appId.toString() + "\n" + + "\t appId: " + appId.getId() + "\n" + + "\t clientToken: " + report.getClientToken() + "\n" + + "\t appDiagnostics: " + report.getDiagnostics() + "\n" + + "\t appMasterHost: " + report.getHost() + "\n" + + "\t appQueue: " + report.getQueue() + "\n" + + "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + + "\t appStartTime: " + report.getStartTime() + "\n" + + "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + + "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + + "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + + "\t appUser: " + report.getUser() + ) + + val state = report.getYarnApplicationState() + val dsStatus = report.getFinalApplicationStatus() + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + return true + } + } + return true + } +} + +object Client { + def main(argStrings: Array[String]) { + val args = new ClientArguments(argStrings) + SparkHadoopUtil.setYarnMode() + new Client(args).run + } + + // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps + def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) { + for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim) + } + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala new file mode 100644 index 0000000000..53b305f7df --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala @@ -0,0 +1,104 @@ +package spark.deploy.yarn + +import spark.util.MemoryParam +import spark.util.IntParam +import collection.mutable.{ArrayBuffer, HashMap} +import spark.scheduler.{InputFormatInfo, SplitInfo} + +// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware ! +class ClientArguments(val args: Array[String]) { + var userJar: String = null + var userClass: String = null + var userArgs: Seq[String] = Seq[String]() + var workerMemory = 1024 + var workerCores = 1 + var numWorkers = 2 + var amUser = System.getProperty("user.name") + var amQueue = System.getProperty("QUEUE", "default") + var amMemory: Int = 512 + // TODO + var inputFormatInfo: List[InputFormatInfo] = null + + parseArgs(args.toList) + + private def parseArgs(inputArgs: List[String]): Unit = { + val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]() + val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]() + + var args = inputArgs + + while (! args.isEmpty) { + + args match { + case ("--jar") :: value :: tail => + userJar = value + args = tail + + case ("--class") :: value :: tail => + userClass = value + args = tail + + case ("--args") :: value :: tail => + userArgsBuffer += value + args = tail + + case ("--master-memory") :: MemoryParam(value) :: tail => + amMemory = value + args = tail + + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + + case ("--worker-memory") :: MemoryParam(value) :: tail => + workerMemory = value + args = tail + + case ("--worker-cores") :: IntParam(value) :: tail => + workerCores = value + args = tail + + case ("--user") :: value :: tail => + amUser = value + args = tail + + case ("--queue") :: value :: tail => + amQueue = value + args = tail + + case Nil => + if (userJar == null || userClass == null) { + printUsageAndExit(1) + } + + case _ => + printUsageAndExit(1, args) + } + } + + userArgs = userArgsBuffer.readOnly + inputFormatInfo = inputFormatMap.values.toList + } + + + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + if (unknownParam != null) { + System.err.println("Unknown/unsupported param " + unknownParam) + } + System.err.println( + "Usage: spark.deploy.yarn.Client [options] \n" + + "Options:\n" + + " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + + " --user USERNAME Run the ApplicationMaster as a different user\n" + ) + System.exit(exitCode) + } + +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala new file mode 100644 index 0000000000..5688f1ab66 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala @@ -0,0 +1,171 @@ +package spark.deploy.yarn + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment + +import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap + +import spark.{Logging, Utils} + +class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String, + slaveId: String, hostname: String, workerMemory: Int, workerCores: Int) + extends Runnable with Logging { + + var rpc: YarnRPC = YarnRPC.create(conf) + var cm: ContainerManager = null + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run = { + logInfo("Starting Worker Container") + cm = connectToCM + startContainer + } + + def startContainer = { + logInfo("Setting up ContainerLaunchContext") + + val ctx = Records.newRecord(classOf[ContainerLaunchContext]) + .asInstanceOf[ContainerLaunchContext] + + ctx.setContainerId(container.getId()) + ctx.setResource(container.getResource()) + val localResources = prepareLocalResources + ctx.setLocalResources(localResources) + + val env = prepareEnvironment + ctx.setEnvironment(env) + + // Extra options for the JVM + var JAVA_OPTS = "" + // Set the JVM memory + val workerMemoryString = workerMemory + "m" + JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " " + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same + // node, spark gc effects all other containers performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is + // limited to subset of cores on a node. +/* + else { + // If no java_opts specified, default to using -XX:+CMSIncrementalMode + // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont want to mess with it. + // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tennent machines + // The options are based on + // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } +*/ + + ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) + val commands = List[String]("java " + + " -server " + + // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. + // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state. + // TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ? + " -XX:OnOutOfMemoryError='kill %p' " + + JAVA_OPTS + + " spark.executor.StandaloneExecutorBackend " + + masterAddress + " " + + slaveId + " " + + hostname + " " + + workerCores + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Setting up worker with commands: " + commands) + ctx.setCommands(commands) + + // Send the start request to the ContainerManager + val startReq = Records.newRecord(classOf[StartContainerRequest]) + .asInstanceOf[StartContainerRequest] + startReq.setContainerLaunchContext(ctx) + cm.startContainer(startReq) + } + + + def prepareLocalResources: HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val locaResources = HashMap[String, LocalResource]() + + // Spark JAR + val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + sparkJarResource.setType(LocalResourceType.FILE) + sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION) + sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_JAR_PATH")))) + sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong) + sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong) + locaResources("spark.jar") = sparkJarResource + // User JAR + val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + userJarResource.setType(LocalResourceType.FILE) + userJarResource.setVisibility(LocalResourceVisibility.APPLICATION) + userJarResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_USERJAR_PATH")))) + userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong) + userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong) + locaResources("app.jar") = userJarResource + + // Log4j conf - if available + if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) { + val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + log4jConfResource.setType(LocalResourceType.FILE) + log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION) + log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_LOG4J_PATH")))) + log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong) + log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong) + locaResources("log4j.properties") = log4jConfResource + } + + + logInfo("Prepared Local resources " + locaResources) + return locaResources + } + + def prepareEnvironment: HashMap[String, String] = { + val env = new HashMap[String, String]() + // should we add this ? + Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment()) + + // If log4j present, ensure ours overrides all others + if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) { + // Which is correct ? + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") + } + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Client.populateHadoopClasspath(yarnConf, env) + + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + return env + } + + def connectToCM: ContainerManager = { + val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort() + val cmAddress = NetUtils.createSocketAddr(cmHostPortStr) + logInfo("Connecting to ContainerManager at " + cmHostPortStr) + return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager] + } + +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala new file mode 100644 index 0000000000..cac9dab401 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala @@ -0,0 +1,547 @@ +package spark.deploy.yarn + +import spark.{Logging, Utils} +import spark.scheduler.SplitInfo +import scala.collection +import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container} +import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend} +import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.util.{RackResolver, Records} +import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} +import java.util.concurrent.atomic.AtomicInteger +import org.apache.hadoop.yarn.api.AMRMProtocol +import collection.JavaConversions._ +import collection.mutable.{ArrayBuffer, HashMap, HashSet} +import org.apache.hadoop.conf.Configuration +import java.util.{Collections, Set => JSet} +import java.lang.{Boolean => JBoolean} + +object AllocationType extends Enumeration ("HOST", "RACK", "ANY") { + type AllocationType = Value + val HOST, RACK, ANY = Value +} + +// too many params ? refactor it 'somehow' ? +// needs to be mt-safe +// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it +// more proactive and decoupled. +// Note that right now, we assume all node asks as uniform in terms of capabilities and priority +// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info +// on how we are requesting for containers. +private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol, + val appAttemptId: ApplicationAttemptId, + val maxWorkers: Int, val workerMemory: Int, val workerCores: Int, + val preferredHostToCount: Map[String, Int], + val preferredRackToCount: Map[String, Int]) + extends Logging { + + + // These three are locked on allocatedHostToContainersMap. Complementary data structures + // allocatedHostToContainersMap : containers which are running : host, Set<containerid> + // allocatedContainerToHostMap: container to host mapping + private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]() + private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node) + // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap + private val allocatedRackCount = new HashMap[String, Int]() + + // containers which have been released. + private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() + // containers to be released in next request to RM + private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + + private val numWorkersRunning = new AtomicInteger() + // Used to generate a unique id per worker + private val workerIdCounter = new AtomicInteger() + private val lastResponseId = new AtomicInteger() + + def getNumWorkersRunning: Int = numWorkersRunning.intValue + + + def isResourceConstraintSatisfied(container: Container): Boolean = { + container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + } + + def allocateContainers(workersToRequest: Int) { + // We need to send the request only once from what I understand ... but for now, not modifying this much. + + // Keep polling the Resource Manager for containers + val amResp = allocateWorkerResources(workersToRequest).getAMResponse + + val _allocatedContainers = amResp.getAllocatedContainers() + if (_allocatedContainers.size > 0) { + + + logDebug("Allocated " + _allocatedContainers.size + " containers, current count " + + numWorkersRunning.get() + ", to-be-released " + releasedContainerList + + ", pendingReleaseContainers : " + pendingReleaseContainers) + logDebug("Cluster Resources: " + amResp.getAvailableResources) + + val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() + + // ignore if not satisfying constraints { + for (container <- _allocatedContainers) { + if (isResourceConstraintSatisfied(container)) { + // allocatedContainers += container + + val host = container.getNodeId.getHost + val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]()) + + containers += container + } + // Add all ignored containers to released list + else releasedContainerList.add(container.getId()) + } + + // Find the appropriate containers to use + // Slightly non trivial groupBy I guess ... + val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (candidateHost <- hostToContainers.keySet) + { + val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) + val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) + + var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null) + assert(remainingContainers != null) + + if (requiredHostCount >= remainingContainers.size){ + // Since we got <= required containers, add all to dataLocalContainers + dataLocalContainers.put(candidateHost, remainingContainers) + // all consumed + remainingContainers = null + } + else if (requiredHostCount > 0) { + // container list has more containers than we need for data locality. + // Split into two : data local container count of (remainingContainers.size - requiredHostCount) + // and rest as remainingContainer + val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount) + dataLocalContainers.put(candidateHost, dataLocal) + // remainingContainers = remaining + + // yarn has nasty habit of allocating a tonne of containers on a host - discourage this : + // add remaining to release list. If we have insufficient containers, next allocation cycle + // will reallocate (but wont treat it as data local) + for (container <- remaining) releasedContainerList.add(container.getId()) + remainingContainers = null + } + + // now rack local + if (remainingContainers != null){ + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + + if (rack != null){ + val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) + val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - + rackLocalContainers.get(rack).getOrElse(List()).size + + + if (requiredRackCount >= remainingContainers.size){ + // Add all to dataLocalContainers + dataLocalContainers.put(rack, remainingContainers) + // all consumed + remainingContainers = null + } + else if (requiredRackCount > 0) { + // container list has more containers than we need for data locality. + // Split into two : data local container count of (remainingContainers.size - requiredRackCount) + // and rest as remainingContainer + val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount) + val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]()) + + existingRackLocal ++= rackLocal + remainingContainers = remaining + } + } + } + + // If still not consumed, then it is off rack host - add to that list. + if (remainingContainers != null){ + offRackContainers.put(candidateHost, remainingContainers) + } + } + + // Now that we have split the containers into various groups, go through them in order : + // first host local, then rack local and then off rack (everything else). + // Note that the list we create below tries to ensure that not all containers end up within a host + // if there are sufficiently large number of hosts/containers. + + val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers) + + // Run each of the allocated containers + for (container <- allocatedContainers) { + val numWorkersRunningNow = numWorkersRunning.incrementAndGet() + val workerHostname = container.getNodeId.getHost + val containerId = container.getId + + assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + + if (numWorkersRunningNow > maxWorkers) { + logInfo("Ignoring container " + containerId + " at host " + workerHostname + + " .. we already have required number of containers") + releasedContainerList.add(containerId) + // reset counter back to old value. + numWorkersRunning.decrementAndGet() + } + else { + // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter) + val workerId = workerIdCounter.incrementAndGet().toString + val masterUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + StandaloneSchedulerBackend.ACTOR_NAME) + + logInfo("launching container on " + containerId + " host " + workerHostname) + // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but .. + pendingReleaseContainers.remove(containerId) + + val rack = YarnAllocationHandler.lookupRack(conf, workerHostname) + allocatedHostToContainersMap.synchronized { + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]()) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, workerHostname) + if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) + } + + new Thread( + new WorkerRunnable(container, conf, masterUrl, workerId, + workerHostname, workerMemory, workerCores) + ).start() + } + } + logDebug("After allocated " + allocatedContainers.size + " containers (orig : " + + _allocatedContainers.size + "), current count " + numWorkersRunning.get() + + ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers) + } + + + val completedContainers = amResp.getCompletedContainersStatuses() + if (completedContainers.size > 0){ + logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() + + ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers) + + for (completedContainer <- completedContainers){ + val containerId = completedContainer.getContainerId + + // Was this released by us ? If yes, then simply remove from containerSet and move on. + if (pendingReleaseContainers.containsKey(containerId)) { + pendingReleaseContainers.remove(containerId) + } + else { + // simply decrement count - next iteration of ReporterThread will take care of allocating ! + numWorkersRunning.decrementAndGet() + logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState + + " httpaddress: " + completedContainer.getDiagnostics) + } + + allocatedHostToContainersMap.synchronized { + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).getOrElse(null) + assert (host != null) + + val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null) + assert (containerSet != null) + + containerSet -= containerId + if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host) + else allocatedHostToContainersMap.update(host, containerSet) + + allocatedContainerToHostMap -= containerId + + // doing this within locked context, sigh ... move to outside ? + val rack = YarnAllocationHandler.lookupRack(conf, host) + if (rack != null) { + val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 + if (rackCount > 0) allocatedRackCount.put(rack, rackCount) + else allocatedRackCount.remove(rack) + } + } + } + } + logDebug("After completed " + completedContainers.size + " containers, current count " + + numWorkersRunning.get() + ", to-be-released " + releasedContainerList + + ", pendingReleaseContainers : " + pendingReleaseContainers) + } + } + + def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = { + // First generate modified racks and new set of hosts under it : then issue requests + val rackToCounts = new HashMap[String, Int]() + + // Within this lock - used to read/write to the rack related maps too. + for (container <- hostContainers) { + val candidateHost = container.getHostName + val candidateNumContainers = container.getNumContainers + assert(YarnAllocationHandler.ANY_HOST != candidateHost) + + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + if (rack != null) { + var count = rackToCounts.getOrElse(rack, 0) + count += candidateNumContainers + rackToCounts.put(rack, count) + } + } + + val requestedContainers: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](rackToCounts.size) + for ((rack, count) <- rackToCounts){ + requestedContainers += + createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY) + } + + requestedContainers.toList + } + + def allocatedContainersOnHost(host: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedHostToContainersMap.getOrElse(host, Set()).size + } + retval + } + + def allocatedContainersOnRack(rack: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedRackCount.getOrElse(rack, 0) + } + retval + } + + private def allocateWorkerResources(numWorkers: Int): AllocateResponse = { + + var resourceRequests: List[ResourceRequest] = null + + // default. + if (numWorkers <= 0 || preferredHostToCount.isEmpty) { + logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty) + resourceRequests = List( + createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)) + } + else { + // request for all hosts in preferred nodes and for numWorkers - + // candidates.size, request by default allocation policy. + val hostContainerRequests: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](preferredHostToCount.size) + for ((candidateHost, candidateCount) <- preferredHostToCount) { + val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) + + if (requiredCount > 0) { + hostContainerRequests += + createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY) + } + } + val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList) + + val anyContainerRequests: ResourceRequest = + createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY) + + val containerRequests: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1) + + containerRequests ++= hostContainerRequests + containerRequests ++= rackContainerRequests + containerRequests += anyContainerRequests + + resourceRequests = containerRequests.toList + } + + val req = Records.newRecord(classOf[AllocateRequest]) + req.setResponseId(lastResponseId.incrementAndGet) + req.setApplicationAttemptId(appAttemptId) + + req.addAllAsks(resourceRequests) + + val releasedContainerList = createReleasedContainerList() + req.addAllReleases(releasedContainerList) + + + + if (numWorkers > 0) { + logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.") + } + else { + logDebug("Empty allocation req .. release : " + releasedContainerList) + } + + for (req <- resourceRequests) { + logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers + + ", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability) + } + resourceManager.allocate(req) + } + + + private def createResourceRequest(requestType: AllocationType.AllocationType, + resource:String, numWorkers: Int, priority: Int): ResourceRequest = { + + // If hostname specified, we need atleast two requests - node local and rack local. + // There must be a third request - which is ANY : that will be specially handled. + requestType match { + case AllocationType.HOST => { + assert (YarnAllocationHandler.ANY_HOST != resource) + + val hostname = resource + val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority) + + // add to host->rack mapping + YarnAllocationHandler.populateRackInfo(conf, hostname) + + nodeLocal + } + + case AllocationType.RACK => { + val rack = resource + createResourceRequestImpl(rack, numWorkers, priority) + } + + case AllocationType.ANY => { + createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority) + } + + case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType) + } + } + + private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = { + + val rsrcRequest = Records.newRecord(classOf[ResourceRequest]) + val memCapability = Records.newRecord(classOf[Resource]) + // There probably is some overhead here, let's reserve a bit more memory. + memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + rsrcRequest.setCapability(memCapability) + + val pri = Records.newRecord(classOf[Priority]) + pri.setPriority(priority) + rsrcRequest.setPriority(pri) + + rsrcRequest.setHostName(hostname) + + rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0)) + rsrcRequest + } + + def createReleasedContainerList(): ArrayBuffer[ContainerId] = { + + val retval = new ArrayBuffer[ContainerId](1) + // iterator on COW list ... + for (container <- releasedContainerList.iterator()){ + retval += container + } + // remove from the original list. + if (! retval.isEmpty) { + releasedContainerList.removeAll(retval) + for (v <- retval) pendingReleaseContainers.put(v, true) + logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " + + pendingReleaseContainers) + } + + retval + } +} + +object YarnAllocationHandler { + + val ANY_HOST = "*" + // all requests are issued with same priority : we do not (yet) have any distinction between + // request types (like map/reduce in hadoop for example) + val PRIORITY = 1 + + // Additional memory overhead - in mb + val MEMORY_OVERHEAD = 384 + + // host to rack map - saved from allocation requests + // We are expecting this not to change. + // Note that it is possible for this to change : and RM will indicate that to us via update + // response to allocate. But we are punting on handling that for now. + private val hostToRack = new ConcurrentHashMap[String, String]() + private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() + + def newAllocator(conf: Configuration, + resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = { + + val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + + + new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers, + args.workerMemory, args.workerCores, hostToCount, rackToCount) + } + + def newAllocator(conf: Configuration, + resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId, + maxWorkers: Int, workerMemory: Int, workerCores: Int, + map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = { + + val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + + new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers, + workerMemory, workerCores, hostToCount, rackToCount) + } + + // A simple method to copy the split info map. + private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) : + // host to count, rack to count + (Map[String, Int], Map[String, Int]) = { + + if (input == null) return (Map[String, Int](), Map[String, Int]()) + + val hostToCount = new HashMap[String, Int] + val rackToCount = new HashMap[String, Int] + + for ((host, splits) <- input) { + val hostCount = hostToCount.getOrElse(host, 0) + hostToCount.put(host, hostCount + splits.size) + + val rack = lookupRack(conf, host) + if (rack != null){ + val rackCount = rackToCount.getOrElse(host, 0) + rackToCount.put(host, rackCount + splits.size) + } + } + + (hostToCount.toMap, rackToCount.toMap) + } + + def lookupRack(conf: Configuration, host: String): String = { + if (! hostToRack.contains(host)) populateRackInfo(conf, host) + hostToRack.get(host) + } + + def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { + val set = rackToHostSet.get(rack) + if (set == null) return None + + // No better way to get a Set[String] from JSet ? + val convertedSet: collection.mutable.Set[String] = set + Some(convertedSet.toSet) + } + + def populateRackInfo(conf: Configuration, hostname: String) { + Utils.checkHost(hostname) + + if (!hostToRack.containsKey(hostname)) { + // If there are repeated failures to resolve, all to an ignore list ? + val rackInfo = RackResolver.resolve(conf, hostname) + if (rackInfo != null && rackInfo.getNetworkLocation != null) { + val rack = rackInfo.getNetworkLocation + hostToRack.put(hostname, rack) + if (! rackToHostSet.containsKey(rack)) { + rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) + } + rackToHostSet.get(rack).add(hostname) + + // Since RackResolver caches, we are disabling this for now ... + } /* else { + // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... + hostToRack.put(hostname, null) + } */ + } + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala new file mode 100644 index 0000000000..ed732d36bf --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -0,0 +1,42 @@ +package spark.scheduler.cluster + +import spark._ +import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.hadoop.conf.Configuration + +/** + * + * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done + */ +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { + + def this(sc: SparkContext) = this(sc, new Configuration()) + + // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate + // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?) + // Subsequent creations are ignored - since nodes are already allocated by then. + + + // By default, rack is unknown + override def getRackForHost(hostPort: String): Option[String] = { + val host = Utils.parseHostPort(hostPort)._1 + val retval = YarnAllocationHandler.lookupRack(conf, host) + if (retval != null) Some(retval) else None + } + + // By default, if rack is unknown, return nothing + override def getCachedHostsForRack(rack: String): Option[Set[String]] = { + if (rack == None || rack == null) return None + + YarnAllocationHandler.fetchCachedHostsForRack(rack) + } + + override def postStartHook() { + val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) + if (sparkContextInitialized){ + // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + Thread.sleep(3000L) + } + logInfo("YarnClusterScheduler.postStartHook done") + } +} diff --git a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..d4badbc5c4 --- /dev/null +++ b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,18 @@ +package spark.deploy + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + def getUserNameFromEnvironment(): String = { + // defaulting to -D ... + System.getProperty("user.name") + } + + def runAsUser(func: (Product) => Unit, args: Product) { + + // Add support, if exists - for now, simply run func ! + func(args) + } +} diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala index 98525b99c8..50d6a1c5c9 100644 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ b/core/src/main/scala/spark/ClosureCleaner.scala @@ -8,12 +8,20 @@ import scala.collection.mutable.Set import org.objectweb.asm.{ClassReader, MethodVisitor, Type} import org.objectweb.asm.commons.EmptyVisitor import org.objectweb.asm.Opcodes._ +import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it private def getClassReader(cls: Class[_]): ClassReader = { - new ClassReader(cls.getResourceAsStream( - cls.getName.replaceFirst("^.*\\.", "") + ".class")) + // Copy data over, before delegating to ClassReader - else we can run out of open file handles. + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + // todo: Fixme - continuing with earlier behavior ... + if (resourceStream == null) return new ClassReader(resourceStream) + + val baos = new ByteArrayOutputStream(128) + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) } // Check whether a class represents a Scala closure diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala index a953081d24..40b0193f19 100644 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ b/core/src/main/scala/spark/FetchFailedException.scala @@ -3,18 +3,25 @@ package spark import spark.storage.BlockManagerId private[spark] class FetchFailedException( - val bmAddress: BlockManagerId, - val shuffleId: Int, - val mapId: Int, - val reduceId: Int, + taskEndReason: TaskEndReason, + message: String, cause: Throwable) extends Exception { - - override def getMessage(): String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + + def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(bmAddress, shuffleId, mapId, reduceId), + "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId), + cause) + + def this (shuffleId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(null, shuffleId, -1, reduceId), + "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause) + + override def getMessage(): String = message + override def getCause(): Throwable = cause - def toTaskEndReason: TaskEndReason = - FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = taskEndReason + } diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 7c1c1bb144..0fc8c31463 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -68,6 +68,10 @@ trait Logging { if (log.isErrorEnabled) log.error(msg, throwable) } + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + // Method for ensuring that logging is initialized, to avoid having multiple // threads do it concurrently (as SLF4J initialization is not thread safe). protected def initLogging() { log } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 866d630a6d..6e9da02893 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -1,7 +1,6 @@ package spark import java.io._ -import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap @@ -12,8 +11,7 @@ import akka.dispatch._ import akka.pattern.ask import akka.remote._ import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ + import spark.scheduler.MapStatus import spark.storage.BlockManagerId @@ -40,10 +38,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac private[spark] class MapOutputTracker extends Logging { + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ - var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -52,7 +52,7 @@ private[spark] class MapOutputTracker extends Logging { // Cache a serialized version of the output statuses for each shuffle to send them out faster var cacheGeneration = generation - val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) @@ -60,7 +60,6 @@ private[spark] class MapOutputTracker extends Logging { // throw a SparkException if this fails. def askTracker(message: Any): Any = { try { - val timeout = 10.seconds val future = trackerActor.ask(message)(timeout) return Await.result(future, timeout) } catch { @@ -77,10 +76,9 @@ private[spark] class MapOutputTracker extends Logging { } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.get(shuffleId) != None) { + if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -101,8 +99,9 @@ private[spark] class MapOutputTracker extends Logging { } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = mapStatuses(shuffleId) - if (array != null) { + var arrayOpt = mapStatuses.get(shuffleId) + if (arrayOpt.isDefined && arrayOpt.get != null) { + var array = arrayOpt.get array.synchronized { if (array(mapId) != null && array(mapId).location == bmAddress) { array(mapId) = null @@ -115,13 +114,14 @@ private[spark] class MapOutputTracker extends Logging { } // Remembers which map output locations are currently being fetched on a worker - val fetching = new HashSet[Int] + private val fetching = new HashSet[Int] // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -132,31 +132,49 @@ private[spark] class MapOutputTracker extends Logging { case e: InterruptedException => } } - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId)) - } else { + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. fetching += shuffleId } } - // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) - val host = System.getProperty("spark.hostname", Utils.localHostName) - // This try-finally prevents hangs due to timeouts: - var fetchedStatuses: Array[MapStatus] = null - try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() + + if (fetchedStatuses == null) { + // We won the race to fetch the output locs; do so + logInfo("Doing the fetch; tracker actor = " + trackerActor) + val hostPort = Utils.localHostPort() + // This try-finally prevents hangs due to timeouts: + var fetchedStatuses: Array[MapStatus] = null + try { + val fetchedBytes = + askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] + fetchedStatuses = deserializeStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + if (fetchedStatuses != null) { + fetchedStatuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) + else{ + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } } else { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } } } @@ -194,7 +212,8 @@ private[spark] class MapOutputTracker extends Logging { generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + mapStatuses.clear() generation = newGen } } @@ -232,10 +251,13 @@ private[spark] class MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - objOut.writeObject(statuses) + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } objOut.close() out.toByteArray } @@ -243,7 +265,10 @@ private[spark] class MapOutputTracker extends Logging { // Opposite of serializeStatuses. def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject().asInstanceOf[Array[MapStatus]] + objIn.readObject(). + // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present + // comment this out - nulls could be due to missing location ? + asInstanceOf[Array[MapStatus]] // .filter( _ != null ) } } @@ -253,14 +278,11 @@ private[spark] object MapOutputTracker { // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If // any of the statuses is null (indicating a missing location due to a failed mapper), // throw a FetchFailedException. - def convertMapStatuses( + private def convertMapStatuses( shuffleId: Int, reduceId: Int, statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { - if (statuses == null) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing all output locations for shuffle " + shuffleId)) - } + assert (statuses != null) statuses.map { status => if (status == null) { diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4957a54c1b..e853bce2c4 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -37,7 +37,7 @@ import spark.partial.PartialResult import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} import spark.scheduler._ import spark.scheduler.local.LocalScheduler -import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} +import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.BlockManagerUI import spark.util.{MetadataCleaner, TimeStampedHashMap} @@ -59,7 +59,10 @@ class SparkContext( val appName: String, val sparkHome: String = null, val jars: Seq[String] = Nil, - val environment: Map[String, String] = Map()) + val environment: Map[String, String] = Map(), + // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too. + // This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()) extends Logging { // Ensure logging is initialized before we spawn any threads @@ -67,7 +70,7 @@ class SparkContext( // Set Spark driver host and port system properties if (System.getProperty("spark.driver.host") == null) { - System.setProperty("spark.driver.host", Utils.localIpAddress) + System.setProperty("spark.driver.host", Utils.localHostName()) } if (System.getProperty("spark.driver.port") == null) { System.setProperty("spark.driver.port", "0") @@ -99,7 +102,7 @@ class SparkContext( // Add each JAR given through the constructor - jars.foreach { addJar(_) } + if (jars != null) jars.foreach { addJar(_) } // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() @@ -111,7 +114,7 @@ class SparkContext( executorEnvs(key) = value } } - executorEnvs ++= environment + if (environment != null) executorEnvs ++= environment // Create and start the scheduler private var taskScheduler: TaskScheduler = { @@ -164,6 +167,22 @@ class SparkContext( } scheduler + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(this).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) + scheduler.initialize(backend) + scheduler + case _ => if (MESOS_REGEX.findFirstIn(master).isEmpty) { logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) @@ -183,7 +202,7 @@ class SparkContext( } taskScheduler.start() - private var dagScheduler = new DAGScheduler(taskScheduler) + @volatile private var dagScheduler = new DAGScheduler(taskScheduler) dagScheduler.start() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ @@ -207,6 +226,9 @@ class SparkContext( private[spark] var checkpointDir: Option[String] = None + // Post init + taskScheduler.postStartHook() + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -471,7 +493,7 @@ class SparkContext( */ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.ip + ":" + blockManagerId.port, mem) + (blockManagerId.host + ":" + blockManagerId.port, mem) } } @@ -527,10 +549,13 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { - if (dagScheduler != null) { + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. + val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { metadataCleaner.cancel() - dagScheduler.stop() - dagScheduler = null + dagSchedulerCopy.stop() taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -546,6 +571,7 @@ class SparkContext( } } + /** * Get Spark's home location from either a value set through the constructor, * or the spark.home Java property, or the SPARK_HOME environment variable diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 7157fd2688..ffb40bab3a 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -72,6 +72,16 @@ object SparkEnv extends Logging { System.setProperty("spark.driver.port", boundPort.toString) } + // set only if unset until now. + if (System.getProperty("spark.hostPort", null) == null) { + if (!isDriver){ + // unexpected + Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") + } + Utils.checkHost(hostname) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + } + val classLoader = Thread.currentThread.getContextClassLoader // Create an instance of the class named by the given Java system property, or by @@ -88,9 +98,10 @@ object SparkEnv extends Logging { logInfo("Registering " + name) actorSystem.actorOf(Props(newActor), name = name) } else { - val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverHost: String = System.getProperty("spark.driver.host", "localhost") val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt - val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name) + Utils.checkHost(driverHost, "Expected hostname") + val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) logInfo("Connecting to " + name + ": " + url) actorSystem.actorFor(url) } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 81daacf958..14bb153d54 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,18 +1,18 @@ package spark import java.io._ -import java.net._ +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} import java.util.{Locale, Random, UUID} -import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder -import scala.Some import spark.serializer.SerializerInstance +import spark.deploy.SparkHadoopUtil /** * Various utility methods used by Spark. @@ -68,6 +68,41 @@ private object Utils extends Logging { return buf } + + private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; else false + // This is to ensure that two shutdown hooks do not try to delete each others paths - resulting in IOException + // and incomplete cleanup + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + + val absolutePath = file.getAbsolutePath() + + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.find(path => ! absolutePath.equals(path) && absolutePath.startsWith(path) ).isDefined + } + + if (retval) logInfo("path = " + file + ", already present as root for deletion.") + + retval + } + /** Create a temporary directory inside the given parent directory */ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { var attempts = 0 @@ -86,10 +121,14 @@ private object Utils extends Logging { } } catch { case e: IOException => ; } } + + registerShutdownDeleteDir(dir) + // Add a shutdown hook to delete the temp dir when the JVM exits Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { override def run() { - Utils.deleteRecursively(dir) + // Attempt to delete if some patch which is parent of this is not already registered. + if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) } }) return dir @@ -227,8 +266,10 @@ private object Utils extends Logging { /** * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). + * Note, this is typically not used from within core spark. */ lazy val localIpAddress: String = findLocalIpAddress() + lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) private def findLocalIpAddress(): String = { val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") @@ -266,6 +307,8 @@ private object Utils extends Logging { * hostname it reports to the master. */ def setCustomHostname(hostname: String) { + // DEBUG code + Utils.checkHost(hostname) customHostname = Some(hostname) } @@ -273,7 +316,90 @@ private object Utils extends Logging { * Get the local machine's hostname. */ def localHostName(): String = { - customHostname.getOrElse(InetAddress.getLocalHost.getHostName) + // customHostname.getOrElse(InetAddress.getLocalHost.getHostName) + customHostname.getOrElse(localIpAddressHostname) + } + + def getAddressHostName(address: String): String = { + InetAddress.getByName(address).getHostName + } + + + + def localHostPort(): String = { + val retval = System.getProperty("spark.hostPort", null) + if (retval == null) { + logErrorWithStack("spark.hostPort not set but invoking localHostPort") + return localHostName() + } + + retval + } + + // Used by DEBUG code : remove when all testing done + def checkHost(host: String, message: String = "") { + // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous ! + if (host.matches("^[0-9]+(\\.[0-9]+)*$")) { + Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message) + } + if (Utils.parseHostPort(host)._2 != 0){ + Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message) + } + } + + // Used by DEBUG code : remove when all testing done + def checkHostPort(hostPort: String, message: String = "") { + val (host, port) = Utils.parseHostPort(hostPort) + checkHost(host) + if (port <= 0){ + Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message) + } + } + + def getUserNameFromEnvironment(): String = { + SparkHadoopUtil.getUserNameFromEnvironment + } + + // Used by DEBUG code : remove when all testing done + def logErrorWithStack(msg: String) { + try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } + // temp code for debug + System.exit(-1) + } + + // Typically, this will be of order of number of nodes in cluster + // If not, we should change it to LRUCache or something. + private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() + def parseHostPort(hostPort: String): (String, Int) = { + { + // Check cache first. + var cached = hostPortParseResults.get(hostPort) + if (cached != null) return cached + } + + val indx: Int = hostPort.lastIndexOf(':') + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now. + // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + if (-1 == indx) { + val retval = (hostPort, 0) + hostPortParseResults.put(hostPort, retval) + return retval + } + + val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) + hostPortParseResults.putIfAbsent(hostPort, retval) + hostPortParseResults.get(hostPort) + } + + def addIfNoPort(hostPort: String, port: Int): String = { + if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port) + + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now. + // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + val indx: Int = hostPort.lastIndexOf(':') + if (-1 != indx) return hostPort + + hostPort + ":" + port } private[spark] val daemonThreadFactory: ThreadFactory = diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 9b4d54ab4e..807119ca8c 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -277,6 +277,8 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte] */ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { + + Utils.checkHost(serverHost, "Expected hostname") override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 8a3e64e4c2..51274acb1e 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, ApplicationInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List +import spark.Utils private[spark] sealed trait DeployMessage extends Serializable @@ -19,7 +20,10 @@ case class RegisterWorker( memory: Int, webUiPort: Int, publicAddress: String) - extends DeployMessage + extends DeployMessage { + Utils.checkHost(host, "Required hostname") + assert (port > 0) +} private[spark] case class ExecutorStateChanged( @@ -58,7 +62,9 @@ private[spark] case class RegisteredApplication(appId: String) extends DeployMessage private[spark] -case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) +case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { + Utils.checkHostPort(hostPort, "Required hostport") +} private[spark] case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -81,6 +87,9 @@ private[spark] case class MasterState(host: String, port: Int, workers: Array[WorkerInfo], activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) { + Utils.checkHost(host, "Required hostname") + assert (port > 0) + def uri = "spark://" + host + ":" + port } @@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState private[spark] case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int, - coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) + coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { + + Utils.checkHost(host, "Required hostname") + assert (port > 0) +} diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index 38a6ebfc24..71a641a9ef 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { def write(obj: WorkerInfo) = JsObject( "id" -> JsString(obj.id), "host" -> JsString(obj.host), + "port" -> JsNumber(obj.port), "webuiaddress" -> JsString(obj.webUiAddress), "cores" -> JsNumber(obj.cores), "coresused" -> JsNumber(obj.coresUsed), diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 22319a96ca..55bb61b0cc 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer private[spark] class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - private val localIpAddress = Utils.localIpAddress + private val localHostname = Utils.localHostName() private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() @@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ - val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0) + val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0) masterActorSystems += masterSystem - val masterUrl = "spark://" + localIpAddress + ":" + masterPort + val masterUrl = "spark://" + localHostname + ":" + masterPort /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker, + val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masterUrl, null, Some(workerNum)) workerActorSystems += workerSystem } diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index 2fc5e657f9..072232e33a 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -59,10 +59,10 @@ private[spark] class Client( markDisconnected() context.stop(self) - case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) => + case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores)) - listener.executorAdded(fullId, workerId, host, cores, memory) + logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) + listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => val fullId = appId + "/" + id diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index b7008321df..e8c4083f9d 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -12,7 +12,7 @@ private[spark] trait ClientListener { def disconnected(): Unit - def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit + def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index dc004b59ca..ad92532b58 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -16,7 +16,7 @@ private[spark] object TestClient { System.exit(0) } - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {} + def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {} def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 71b9d0801d..160afe5239 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils} import spark.util.AkkaUtils -private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { +private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 @@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor var firstApp: Option[ApplicationInfo] = None + Utils.checkHost(host, "Expected hostname") + val masterPublicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else ip + if (envVar != null) envVar else host } // As a temporary workaround before better ways of configuring memory, we allow users to set @@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean override def preStart() { - logInfo("Starting Spark master at spark://" + ip + ":" + port) + logInfo("Starting Spark master at spark://" + host + ":" + port) // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) startWebUi() @@ -145,7 +147,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } case RequestMasterState => { - sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray) + sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray) } } @@ -211,13 +213,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) - exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) + exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, publicAddress: String): WorkerInfo = { // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them. - workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _) + workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _) val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) workers += worker idToWorker(worker.id) = worker @@ -307,7 +309,7 @@ private[spark] object Master { def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort) + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala index 4ceab3fc03..3d28ecabb4 100644 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala @@ -7,13 +7,13 @@ import spark.Utils * Command-line parser for the master. */ private[spark] class MasterArguments(args: Array[String]) { - var ip = Utils.localHostName() + var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 // Check for settings in environment variables - if (System.getenv("SPARK_MASTER_IP") != null) { - ip = System.getenv("SPARK_MASTER_IP") + if (System.getenv("SPARK_MASTER_HOST") != null) { + host = System.getenv("SPARK_MASTER_HOST") } if (System.getenv("SPARK_MASTER_PORT") != null) { port = System.getenv("SPARK_MASTER_PORT").toInt @@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) { def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - ip = value + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value parse(tail) case ("--port" | "-p") :: IntParam(value) :: tail => @@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) { "Usage: Master [options]\n" + "\n" + "Options:\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + " --webui-port PORT Port for web UI (default: 8080)") System.exit(exitCode) diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala index 23df1bb463..0c08c5f417 100644 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala @@ -2,6 +2,7 @@ package spark.deploy.master import akka.actor.ActorRef import scala.collection.mutable +import spark.Utils private[spark] class WorkerInfo( val id: String, @@ -13,6 +14,9 @@ private[spark] class WorkerInfo( val webUiPort: Int, val publicAddress: String) { + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info var state: WorkerState.Value = WorkerState.ALIVE var coresUsed = 0 @@ -23,6 +27,11 @@ private[spark] class WorkerInfo( def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed + def hostPort: String = { + assert (port > 0) + host + ":" + port + } + def addExecutor(exec: ExecutorInfo) { executors(exec.fullId) = exec coresUsed += exec.cores diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index de11771c8e..dfcb9f0d05 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -21,11 +21,13 @@ private[spark] class ExecutorRunner( val memory: Int, val worker: ActorRef, val workerId: String, - val hostname: String, + val hostPort: String, val sparkHome: File, val workDir: File) extends Logging { + Utils.checkHostPort(hostPort, "Expected hostport") + val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null @@ -68,7 +70,7 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString - case "{{HOSTNAME}}" => hostname + case "{{HOSTPORT}}" => hostPort case "{{CORES}}" => cores.toString case other => other } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 8919d1261c..cf4babc892 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -16,7 +16,7 @@ import spark.deploy.master.Master import java.io.File private[spark] class Worker( - ip: String, + host: String, port: Int, webUiPort: Int, cores: Int, @@ -25,6 +25,9 @@ private[spark] class Worker( workDirPath: String = null) extends Actor with Logging { + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs // Send a heartbeat every (heartbeat timeout) / 4 milliseconds @@ -39,7 +42,7 @@ private[spark] class Worker( val finishedExecutors = new HashMap[String, ExecutorRunner] val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else ip + if (envVar != null) envVar else host } var coresUsed = 0 @@ -64,7 +67,7 @@ private[spark] class Worker( override def preStart() { logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( - ip, port, cores, Utils.memoryMegabytesToString(memory))) + host, port, cores, Utils.memoryMegabytesToString(memory))) sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) logInfo("Spark home: " + sparkHome) createWorkDir() @@ -75,7 +78,7 @@ private[spark] class Worker( def connectToMaster() { logInfo("Connecting to master " + masterUrl) master = context.actorFor(Master.toAkkaUrl(masterUrl)) - master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) + master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } @@ -106,7 +109,7 @@ private[spark] class Worker( case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner( - appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) + appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -141,7 +144,7 @@ private[spark] class Worker( masterDisconnected() case RequestWorkerState => { - sender ! WorkerState(ip, port, workerId, executors.values.toList, + sender ! WorkerState(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, masterUrl, cores, memory, coresUsed, memoryUsed, masterWebUiUrl) } @@ -156,7 +159,7 @@ private[spark] class Worker( } def generateWorkerId(): String = { - "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port) + "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port) } override def postStop() { @@ -167,7 +170,7 @@ private[spark] class Worker( private[spark] object Worker { def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores, + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.master, args.workDir) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 08f02bad80..2b96611ee3 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory * Command-line parser for the master. */ private[spark] class WorkerArguments(args: Array[String]) { - var ip = Utils.localHostName() + var host = Utils.localHostName() var port = 0 var webUiPort = 8081 var cores = inferDefaultCores() @@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) { def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - ip = value + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value parse(tail) case ("--port" | "-p") :: IntParam(value) :: tail => @@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) { " -c CORES, --cores CORES Number of cores to use\n" + " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + " --webui-port PORT Port for web UI (default: 8081)") System.exit(exitCode) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 3e7407b58d..344face5e6 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -17,7 +17,7 @@ import java.nio.ByteBuffer * The Mesos executor for Spark. */ private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging { - + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert initLogging() + // No ip or host:port - just hostname + Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + // must not have port specified. + assert (0 == Utils.parseHostPort(slaveHostname)._2) + // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 1047f71c6a..49e1f3f07a 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask import spark.scheduler.cluster.RegisterExecutorFailed import spark.scheduler.cluster.RegisterExecutor +import spark.Utils +import spark.deploy.SparkHadoopUtil private[spark] class StandaloneExecutorBackend( driverUrl: String, executorId: String, - hostname: String, + hostPort: String, cores: Int) extends Actor with ExecutorBackend with Logging { + Utils.checkHostPort(hostPort, "Expected hostport") + var executor: Executor = null var driver: ActorRef = null override def preStart() { logInfo("Connecting to driver: " + driverUrl) driver = context.actorFor(driverUrl) - driver ! RegisterExecutor(executorId, hostname, cores) + driver ! RegisterExecutor(executorId, hostPort, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(driver) // Doesn't work with remote actors, but useful for testing } @@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend( override def receive = { case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with driver") - executor = new Executor(executorId, hostname, sparkProperties) + // Make this host instead of hostPort ? + executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -63,11 +68,29 @@ private[spark] class StandaloneExecutorBackend( private[spark] object StandaloneExecutorBackend { def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { + SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores)) + } + + // This will be run 'as' the user + def run0(args: Product) { + assert(4 == args.productArity) + runImpl(args.productElement(0).asInstanceOf[String], + args.productElement(0).asInstanceOf[String], + args.productElement(0).asInstanceOf[String], + args.productElement(0).asInstanceOf[Int]) + } + + private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) + // Debug code + Utils.checkHost(hostname) + // set it + val sparkHostPort = hostname + ":" + boundPort + System.setProperty("spark.hostPort", sparkHostPort) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)), + Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), name = "Executor") actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index d1451bc212..00a0433a44 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -13,7 +13,7 @@ import java.net._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val remoteConnectionManagerId: ConnectionManagerId) extends Logging { + val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging { def this(channel_ : SocketChannel, selector_ : Selector) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( @@ -32,16 +32,43 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() + + // Read channels typically do not register for write and write does not for read + // Now, we do have write registering for read too (temporarily), but this is to detect + // channel close NOT to actually read/consume data on it ! + // How does this work if/when we move to SSL ? + + // What is the interest to register with selector for when we want this connection to be selected + def registerInterest() + // What is the interest to register with selector for when we want this connection to be de-selected + // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be + // SelectionKey.OP_READ (until we fix it properly) + def unregisterInterest() + + // On receiving a read event, should we change the interest for this channel or not ? + // Will be true for ReceivingConnection, false for SendingConnection. + def changeInterestForRead(): Boolean + + // On receiving a write event, should we change the interest for this channel or not ? + // Will be false for ReceivingConnection, true for SendingConnection. + // Actually, for now, should not get triggered for ReceivingConnection + def changeInterestForWrite(): Boolean + + def getRemoteConnectionManagerId(): ConnectionManagerId = { + socketRemoteConnectionManagerId + } def key() = channel.keyFor(selector) def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - def read() { + // Returns whether we have to register for further reads or not. + def read(): Boolean = { throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) } - - def write() { + + // Returns whether we have to register for further writes or not. + def write(): Boolean = { throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) } @@ -64,7 +91,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, if (onExceptionCallback != null) { onExceptionCallback(this, e) } else { - logError("Error in connection to " + remoteConnectionManagerId + + logError("Error in connection to " + getRemoteConnectionManagerId() + " and OnExceptionCallback not registered", e) } } @@ -73,7 +100,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, if (onCloseCallback != null) { onCloseCallback(this) } else { - logWarning("Connection to " + remoteConnectionManagerId + + logWarning("Connection to " + getRemoteConnectionManagerId() + " closed and OnExceptionCallback not registered") } @@ -81,7 +108,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, def changeConnectionKeyInterest(ops: Int) { if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) + onKeyInterestChangeCallback(this, ops) } else { throw new Exception("OnKeyInterestChangeCallback not registered") } @@ -122,7 +149,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]") } } @@ -149,9 +176,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } return chunk } else { - /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ + /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "] in " + message.timeTaken ) } } @@ -170,15 +197,15 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.enqueue(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { - logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]") + logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") message.started = true message.startTime = System.currentTimeMillis } - logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") + logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") return chunk } else { message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "] in " + message.timeTaken ) } } @@ -187,26 +214,39 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } - val outbox = new Outbox(1) + private val outbox = new Outbox(1) val currentBuffers = new ArrayBuffer[ByteBuffer]() /*channel.socket.setSendBufferSize(256 * 1024)*/ - override def getRemoteAddress() = address + override def getRemoteAddress() = address + val DEFAULT_INTEREST = SelectionKey.OP_READ + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. + changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(DEFAULT_INTEREST) + } + def send(message: Message) { outbox.synchronized { outbox.addMessage(message) if (channel.isConnected) { - changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) + registerInterest() } } } + // MUST be called within the selector loop def connect() { try{ - channel.connect(address) channel.register(selector, SelectionKey.OP_CONNECT) + channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { case e: Exception => { @@ -216,20 +256,33 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } - def finishConnect() { + def finishConnect(force: Boolean): Boolean = { try { - channel.finishConnect - changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) + // Typically, this should finish immediately since it was triggered by a connect + // selection - though need not necessarily always complete successfully. + val connected = channel.finishConnect + if (!force && !connected) { + logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") + return false + } + + // Fallback to previous behavior - assume finishConnect completed + // This will happen only when finishConnect failed for some repeated number of times (10 or so) + // Is highly unlikely unless there was an unclean close of socket, etc + registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") + return true } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) callOnExceptionCallback(e) + // ignore + return true } } } - override def write() { + override def write(): Boolean = { try{ while(true) { if (currentBuffers.size == 0) { @@ -239,8 +292,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { currentBuffers ++= chunk.buffers } case None => { - changeConnectionKeyInterest(SelectionKey.OP_READ) - return + // changeConnectionKeyInterest(0) + /*key.interestOps(0)*/ + return false } } } @@ -254,38 +308,53 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { currentBuffers -= buffer } if (writtenBytes < remainingBytes) { - return + // re-register for write. + return true } } } } catch { case e: Exception => { - logWarning("Error writing in connection to " + remoteConnectionManagerId, e) + logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() + return false } } + // should not happen - to keep scala compiler happy + return true } - override def read() { + // This is a hack to determine if remote socket was closed or not. + // SendingConnection DOES NOT expect to receive any data - if it does, it is an error + // For a bunch of cases, read will return -1 in case remote socket is closed : hence we + // register for reads to determine that. + override def read(): Boolean = { // We don't expect the other side to send anything; so, we just read to detect an error or EOF. try { val length = channel.read(ByteBuffer.allocate(1)) if (length == -1) { // EOF close() } else if (length > 0) { - logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId) + logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) } } catch { case e: Exception => - logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e) + logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() } + + false } + + override def changeInterestForRead(): Boolean = false + + override def changeInterestForWrite(): Boolean = true } +// Must be created within selector loop - else deadlock private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) extends Connection(channel_, selector_) { @@ -298,13 +367,13 @@ extends Connection(channel_, selector_) { val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis - logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]") + logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) newMessage } val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]") + logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") message.getChunkForReceiving(header.chunkSize) } @@ -316,7 +385,27 @@ extends Connection(channel_, selector_) { messages -= message.id } } - + + @volatile private var inferredRemoteManagerId: ConnectionManagerId = null + override def getRemoteConnectionManagerId(): ConnectionManagerId = { + val currId = inferredRemoteManagerId + if (currId != null) currId else super.getRemoteConnectionManagerId() + } + + // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver. + // We infer that from the messages we receive on the receiver socket. + private def processConnectionManagerId(header: MessageChunkHeader) { + val currId = inferredRemoteManagerId + if (header.address == null || currId != null) return + + val managerId = ConnectionManagerId.fromSocketAddress(header.address) + + if (managerId != null) { + inferredRemoteManagerId = managerId + } + } + + val inbox = new Inbox() val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) var onReceiveCallback: (Connection , Message) => Unit = null @@ -324,17 +413,18 @@ extends Connection(channel_, selector_) { channel.register(selector, SelectionKey.OP_READ) - override def read() { + override def read(): Boolean = { try { while (true) { if (currentChunk == null) { val headerBytesRead = channel.read(headerBuffer) if (headerBytesRead == -1) { close() - return + return false } if (headerBuffer.remaining > 0) { - return + // re-register for read event ... + return true } headerBuffer.flip if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { @@ -342,6 +432,9 @@ extends Connection(channel_, selector_) { } val header = MessageChunkHeader.create(headerBuffer) headerBuffer.clear() + + processConnectionManagerId(header) + header.typ match { case Message.BUFFER_MESSAGE => { if (header.totalSize == 0) { @@ -349,7 +442,8 @@ extends Connection(channel_, selector_) { onReceiveCallback(this, Message.create(header)) } currentChunk = null - return + // re-register for read event ... + return true } else { currentChunk = inbox.getChunk(header).orNull } @@ -362,10 +456,11 @@ extends Connection(channel_, selector_) { val bytesRead = channel.read(currentChunk.buffer) if (bytesRead == 0) { - return + // re-register for read event ... + return true } else if (bytesRead == -1) { close() - return + return false } /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ @@ -376,7 +471,7 @@ extends Connection(channel_, selector_) { if (bufferMessage.isCompletelyReceived) { bufferMessage.flip bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken) + logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) if (onReceiveCallback != null) { onReceiveCallback(this, bufferMessage) } @@ -387,12 +482,31 @@ extends Connection(channel_, selector_) { } } catch { case e: Exception => { - logWarning("Error reading from connection to " + remoteConnectionManagerId, e) + logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() + return false } } + // should not happen - to keep scala compiler happy + return true } def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} + + override def changeInterestForRead(): Boolean = true + + override def changeInterestForWrite(): Boolean = { + throw new IllegalStateException("Unexpected invocation right now") + } + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. + changeConnectionKeyInterest(SelectionKey.OP_READ) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(0) + } } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index b6ec664d7e..0c6bdb1559 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -6,12 +6,12 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ -import java.util.concurrent.Executors +import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} +import scala.collection.mutable.HashSet import scala.collection.mutable.HashMap import scala.collection.mutable.SynchronizedMap import scala.collection.mutable.SynchronizedQueue -import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer import akka.dispatch.{Await, Promise, ExecutionContext, Future} @@ -19,6 +19,10 @@ import akka.util.Duration import akka.util.duration._ private[spark] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + def toSocketAddress() = new InetSocketAddress(host, port) } @@ -42,19 +46,37 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def markDone() { completionHandler(this) } } - val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt) - val serverChannel = ServerSocketChannel.open() - val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - val sendMessageRequests = new Queue[(Message, SendingConnection)] + private val selector = SelectorProvider.provider.openSelector() + + private val handleMessageExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.handler.threads.min","20").toInt, + System.getProperty("spark.core.connection.handler.threads.max","60").toInt, + System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val handleReadWriteExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.io.threads.min","4").toInt, + System.getProperty("spark.core.connection.io.threads.max","32").toInt, + System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap + private val handleConnectExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.connect.threads.min","1").toInt, + System.getProperty("spark.core.connection.connect.threads.max","8").toInt, + System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val serverChannel = ServerSocketChannel.open() + private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] + private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] + private val messageStatuses = new HashMap[Int, MessageStatus] + private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + private val registerRequests = new SynchronizedQueue[SendingConnection] implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) - var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) @@ -65,46 +87,139 @@ private[spark] class ConnectionManager(port: Int) extends Logging { val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - val selectorThread = new Thread("connection-manager-thread") { + + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } selectorThread.setDaemon(true) selectorThread.start() - private def run() { - try { - while(!selectorThread.isInterrupted) { - for ((connectionManagerId, sendingConnection) <- connectionRequests) { - sendingConnection.connect() - addConnection(sendingConnection) - connectionRequests -= connectionManagerId + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerWrite(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + writeRunnableStarted.synchronized { + // So that we do not trigger more write events while processing this one. + // The write method will re-register when done. + if (conn.changeInterestForWrite()) conn.unregisterInterest() + if (writeRunnableStarted.contains(key)) { + // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) + return + } + + writeRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + if (register && conn.changeInterestForWrite()) { + conn.registerInterest() + } + } } - sendMessageRequests.synchronized { - while (!sendMessageRequests.isEmpty) { - val (message, connection) = sendMessageRequests.dequeue - connection.send(message) + } + } ) + } + + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerRead(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + readRunnableStarted.synchronized { + // So that we do not trigger more read events while processing this one. + // The read method will re-register when done. + if (conn.changeInterestForRead())conn.unregisterInterest() + if (readRunnableStarted.contains(key)) { + return + } + + readRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } + } ) + } + + private def triggerConnect(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] + if (conn == null) return + + // prevent other events from being triggered + // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite + conn.changeConnectionKeyInterest(0) + + handleConnectExecutor.execute(new Runnable { + override def run() { + + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - while (!keyInterestChangeRequests.isEmpty) { + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need not + // succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } + } ) + } + + def run() { + try { + while(!selectorThread.isInterrupted) { + while (! registerRequests.isEmpty) { + val conn: SendingConnection = registerRequests.dequeue + addListeners(conn) + conn.connect() + addConnection(conn) + } + + while(!keyInterestChangeRequests.isEmpty) { val (key, ops) = keyInterestChangeRequests.dequeue - val connection = connectionsByKey(key) - val lastOps = key.interestOps() - key.interestOps(ops) - - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + val connection = connectionsByKey.getOrElse(key, null) + if (connection != null) { + val lastOps = key.interestOps() + key.interestOps(ops) + + // hot loop - prevent materialization of string if trace not enabled. + if (isTraceEnabled()) { + def intToOpStr(op: Int): String = { + val opStrs = ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + + "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } } - - logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId + - "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } val selectedKeysCount = selector.select() @@ -123,12 +238,15 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (key.isValid) { if (key.isAcceptable) { acceptConnection(key) - } else if (key.isConnectable) { - connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect() - } else if (key.isReadable) { - connectionsByKey(key).read() - } else if (key.isWritable) { - connectionsByKey(key).write() + } else + if (key.isConnectable) { + triggerConnect(key) + } else + if (key.isReadable) { + triggerRead(key) + } else + if (key.isWritable) { + triggerWrite(key) } } } @@ -138,94 +256,116 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - private def acceptConnection(key: SelectionKey) { + def acceptConnection(key: SelectionKey) { val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - val newChannel = serverChannel.accept() - val newConnection = new ReceivingConnection(newChannel, selector) - newConnection.onReceive(receiveMessage) - newConnection.onClose(removeConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") - } - private def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection)) + var newChannel = serverChannel.accept() + + // accept them all in a tight loop. non blocking accept with no processing, should be fine + while (newChannel != null) { + try { + val newConnection = new ReceivingConnection(newChannel, selector) + newConnection.onReceive(receiveMessage) + addListeners(newConnection) + addConnection(newConnection) + logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") + } catch { + // might happen in case of issues with registering with selector + case e: Exception => logError("Error in accept loop", e) + } + + newChannel = serverChannel.accept() } + } + + private def addListeners(connection: Connection) { connection.onKeyInterestChange(changeConnectionKeyInterest) connection.onException(handleConnectionError) connection.onClose(removeConnection) } - private def removeConnection(connection: Connection) { + def addConnection(connection: Connection) { + connectionsByKey += ((connection.key, connection)) + } + + def removeConnection(connection: Connection) { connectionsByKey -= connection.key - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + + try { + if (connection.isInstanceOf[SendingConnection]) { + val sendingConnection = connection.asInstanceOf[SendingConnection] + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + logInfo("Removing SendingConnection to " + sendingConnectionManagerId) + + connectionsById -= sendingConnectionManagerId + + messageStatuses.synchronized { + messageStatuses + .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.markDone() + } + }) + + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId }) + } + } else if (connection.isInstanceOf[ReceivingConnection]) { + val receivingConnection = connection.asInstanceOf[ReceivingConnection] + val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() + logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) + + val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) + if (! sendingConnectionOpt.isDefined) { + logError("Corresponding SendingConnectionManagerId not found") + return + } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull - if (sendingConnectionManagerId == null) { - logError("Corresponding SendingConnectionManagerId not found") - return - } - logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId) - - val sendingConnection = connectionsById(sendingConnectionManagerId) - sendingConnection.close() - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() + val sendingConnection = sendingConnectionOpt.get + connectionsById -= remoteConnectionManagerId + sendingConnection.close() + + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + + assert (sendingConnectionManagerId == remoteConnectionManagerId) + + messageStatuses.synchronized { + for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { + logInfo("Notifying " + s) + s.synchronized { + s.attempted = true + s.acked = false + s.markDone() + } } - } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } } + } finally { + // So that the selection keys can be removed. + wakeupSelector() } } - private def handleConnectionError(connection: Connection, e: Exception) { - logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId) + def handleConnectionError(connection: Connection, e: Exception) { + logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) } - private def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) + def changeConnectionKeyInterest(connection: Connection, ops: Int) { + keyInterestChangeRequests += ((connection.key, ops)) + // so that registerations happen ! + wakeupSelector() } - private def receiveMessage(connection: Connection, message: Message) { + def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { @@ -293,18 +433,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, - new SendingConnection(inetSocketAddress, selector, connectionManagerId)) - newConnection + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + registerRequests.enqueue(newConnection) + + newConnection } - val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) - val connection = connectionsById.getOrElse(lookupKey, startNewConnection()) + // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... + // If we do re-add it, we should consistently use it everywhere I guess ? + val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - /*connection.send(message)*/ - sendMessageRequests.synchronized { - sendMessageRequests += ((message, connection)) - } + connection.send(message) + + wakeupSelector() + } + + private def wakeupSelector() { selector.wakeup() } @@ -337,6 +481,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logWarning("All connections not cleaned up") } handleMessageExecutor.shutdown() + handleReadWriteExecutor.shutdown() + handleConnectExecutor.shutdown() logInfo("ConnectionManager stopped") } } diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 525751b5bf..34fac9e776 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -17,7 +17,8 @@ private[spark] class MessageChunkHeader( val other: Int, val address: InetSocketAddress) { lazy val buffer = { - val ip = address.getAddress.getAddress() + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection + val ip = address.getAddress.getAddress() val port = address.getPort() ByteBuffer. allocate(MessageChunkHeader.HEADER_SIZE). diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index c54dce51d7..1440b93e65 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -50,6 +50,11 @@ class DAGScheduler( eventQueue.put(ExecutorLost(execId)) } + // Called by TaskScheduler when a host is added + override def executorGained(execId: String, hostPort: String) { + eventQueue.put(ExecutorGained(execId, hostPort)) + } + // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. override def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) @@ -113,7 +118,7 @@ class DAGScheduler( if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { - locations => locations.map(_.ip).toList + locations => locations.map(_.hostPort).toList }.toArray } cacheLocs(rdd.id) @@ -293,6 +298,9 @@ class DAGScheduler( submitStage(finalStage) } + case ExecutorGained(execId, hostPort) => + handleExecutorGained(execId, hostPort) + case ExecutorLost(execId) => handleExecutorLost(execId) @@ -630,6 +638,14 @@ class DAGScheduler( "(generation " + currentGeneration + ")") } } + + private def handleExecutorGained(execId: String, hostPort: String) { + // remove from failedGeneration(execId) ? + if (failedGeneration.contains(execId)) { + logInfo("Host gained which was in lost list earlier: " + hostPort) + failedGeneration -= execId + } + } /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index ed0b9bf178..b46bb863f0 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -32,6 +32,10 @@ private[spark] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent +private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent { + Utils.checkHostPort(hostPort, "Required hostport") +} + private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala new file mode 100644 index 0000000000..287f731787 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala @@ -0,0 +1,156 @@ +package spark.scheduler + +import spark.Logging +import scala.collection.immutable.Set +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.conf.Configuration +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ + + +/** + * Parses and holds information about inputFormat (and files) specified as a parameter. + */ +class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], + val path: String) extends Logging { + + var mapreduceInputFormat: Boolean = false + var mapredInputFormat: Boolean = false + + validate() + + override def toString(): String = { + "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + path.hashCode + hashCode + } + + // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path + // .. which is fine, this is best case effort to remove duplicates - right ? + override def equals(other: Any): Boolean = other match { + case that: InputFormatInfo => { + // not checking config - that should be fine, right ? + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path + } + case _ => false + } + + private def validate() { + logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path) + + try { + if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapreduce package") + mapreduceInputFormat = true + } + else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapred package") + mapredInputFormat = true + } + else { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + + " is NOT a supported input format ? does not implement either of the supported hadoop api's") + } + } + catch { + case e: ClassNotFoundException => { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) + } + } + } + + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { + val conf = new JobConf(configuration) + FileInputFormat.setInputPaths(conf, path) + + val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ + org.apache.hadoop.mapreduce.InputFormat[_, _]] + val job = new Job(conf) + + val retval = new ArrayBuffer[SplitInfo]() + val list = instance.getSplits(job) + for (split <- list) { + retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) + } + + return retval.toSet + } + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { + val jobConf = new JobConf(configuration) + FileInputFormat.setInputPaths(jobConf, path) + + val instance: org.apache.hadoop.mapred.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[ + org.apache.hadoop.mapred.InputFormat[_, _]] + + val retval = new ArrayBuffer[SplitInfo]() + instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach( + elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) + ) + + return retval.toSet + } + + private def findPreferredLocations(): Set[SplitInfo] = { + logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + + ", inputFormatClazz : " + inputFormatClazz) + if (mapreduceInputFormat) { + return prefLocsFromMapreduceInputFormat() + } + else { + assert(mapredInputFormat) + return prefLocsFromMapredInputFormat() + } + } +} + + + + +object InputFormatInfo { + /** + Computes the preferred locations based on input(s) and returned a location to block map. + Typical use of this method for allocation would follow some algo like this + (which is what we currently do in YARN branch) : + a) For each host, count number of splits hosted on that host. + b) Decrement the currently allocated containers on that host. + c) Compute rack info for each host and update rack -> count map based on (b). + d) Allocate nodes based on (c) + e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + (even if data locality on that is very high) : this is to prevent fragility of job if a single + (or small set of) hosts go down. + + go to (a) until required nodes are allocated. + + If a node 'dies', follow same procedure. + + PS: I know the wording here is weird, hopefully it makes some sense ! + */ + def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = { + + val nodeToSplit = new HashMap[String, HashSet[SplitInfo]] + for (inputSplit <- formats) { + val splits = inputSplit.findPreferredLocations() + + for (split <- splits){ + val location = split.hostLocation + val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo]) + set += split + } + } + + nodeToSplit + } +} diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index beb21a76fe..89dc6640b2 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -70,6 +70,14 @@ private[spark] class ResultTask[T, U]( rdd.partitions(partition) } + // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. + val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs)) + } + override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) metrics = Some(context.taskMetrics) @@ -80,7 +88,7 @@ private[spark] class ResultTask[T, U]( } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partition + ")" diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 36d087a4d0..7dc6da4573 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -77,13 +77,21 @@ private[spark] class ShuffleMapTask( var rdd: RDD[_], var dep: ShuffleDependency[_,_], var partition: Int, - @transient var locs: Seq[String]) + @transient private var locs: Seq[String]) extends Task[MapStatus](stageId) with Externalizable with Logging { protected def this() = this(0, null, null, 0, null) + // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. + private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs)) + } + var split = if (rdd == null) { null } else { @@ -154,7 +162,7 @@ private[spark] class ShuffleMapTask( } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) } diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala new file mode 100644 index 0000000000..6abfb7a1f7 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/SplitInfo.scala @@ -0,0 +1,61 @@ +package spark.scheduler + +import collection.mutable.ArrayBuffer + +// information about a specific split instance : handles both split instances. +// So that we do not need to worry about the differences. +class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String, + val length: Long, val underlyingSplit: Any) { + override def toString(): String = { + "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + + ", hostLocation : " + hostLocation + ", path : " + path + + ", length : " + length + ", underlyingSplit " + underlyingSplit + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + hostLocation.hashCode + hashCode = hashCode * 31 + path.hashCode + // ignore overflow ? It is hashcode anyway ! + hashCode = hashCode * 31 + (length & 0x7fffffff).toInt + hashCode + } + + // This is practically useless since most of the Split impl's dont seem to implement equals :-( + // So unless there is identity equality between underlyingSplits, it will always fail even if it + // is pointing to same block. + override def equals(other: Any): Boolean = other match { + case that: SplitInfo => { + this.hostLocation == that.hostLocation && + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path && + this.length == that.length && + // other split specific checks (like start for FileSplit) + this.underlyingSplit == that.underlyingSplit + } + case _ => false + } +} + +object SplitInfo { + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapredSplit.getLength + for (host <- mapredSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit) + } + retval + } + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapreduceSplit.getLength + for (host <- mapreduceSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit) + } + retval + } +} diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala index d549b184b0..7787b54762 100644 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala @@ -10,6 +10,10 @@ package spark.scheduler private[spark] trait TaskScheduler { def start(): Unit + // Invoked after system has successfully initialized (typically in spark context). + // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. + def postStartHook() { } + // Disconnect from the cluster. def stop(): Unit diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala index 771518dddf..b75d3736cf 100644 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener { def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit + // A node was added to the cluster. + def executorGained(execId: String, hostPort: String): Unit + // A node was lost from the cluster. def executorLost(execId: String): Unit diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 26fdef101b..2e18d46edc 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -1,6 +1,6 @@ package spark.scheduler.cluster -import java.io.{File, FileInputStream, FileOutputStream} +import java.lang.{Boolean => JBoolean} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -25,6 +25,30 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong + // How often to revive offers in case there are pending tasks - that is how often to try to get + // tasks scheduled in case there are nodes available : default 0 is to disable it - to preserve existing behavior + // Note that this is required due to delayed scheduling due to data locality waits, etc. + // TODO: rename property ? + val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong + + /* + This property controls how aggressive we should be to modulate waiting for host local task scheduling. + To elaborate, currently there is a time limit (3 sec def) to ensure that spark attempts to wait for host locality of tasks before + scheduling on other nodes. We have modified this in yarn branch such that offers to task set happen in prioritized order : + host-local, rack-local and then others + But once all available host local (and no pref) tasks are scheduled, instead of waiting for 3 sec before + scheduling to other nodes (which degrades performance for time sensitive tasks and on larger clusters), we can + modulate that : to also allow rack local nodes or any node. The default is still set to HOST - so that previous behavior is + maintained. This is to allow tuning the tension between pulling rdd data off node and scheduling computation asap. + + TODO: rename property ? The value is one of + - HOST_LOCAL (default, no change w.r.t current behavior), + - RACK_LOCAL and + - ANY + + Note that this property makes more sense when used in conjugation with spark.tasks.revive.interval > 0 : else it is not very effective. + */ + val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "HOST_LOCAL")) val activeTaskSets = new HashMap[String, TaskSetManager] var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] @@ -33,9 +57,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] - var hasReceivedTask = false - var hasLaunchedTask = false - val starvationTimer = new Timer(true) + @volatile private var hasReceivedTask = false + @volatile private var hasLaunchedTask = false + private val starvationTimer = new Timer(true) // Incrementing Mesos task IDs val nextTaskId = new AtomicLong(0) @@ -43,11 +67,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Which executor IDs we have executors on val activeExecutorIds = new HashSet[String] + // TODO: We might want to remove this and merge it with execId datastructures - but later. + // Which hosts in the cluster are alive (contains hostPort's) + private val hostPortsAlive = new HashSet[String] + private val hostToAliveHostPorts = new HashMap[String, HashSet[String]] + // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - val executorsByHost = new HashMap[String, HashSet[String]] + val executorsByHostPort = new HashMap[String, HashSet[String]] - val executorIdToHost = new HashMap[String, String] + val executorIdToHostPort = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -75,11 +104,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def start() { backend.start() - if (System.getProperty("spark.speculation", "false") == "true") { + if (JBoolean.getBoolean("spark.speculation")) { new Thread("ClusterScheduler speculation check") { setDaemon(true) override def run() { + logInfo("Starting speculative execution thread") while (true) { try { Thread.sleep(SPECULATION_INTERVAL) @@ -91,6 +121,27 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } }.start() } + + + // Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ? + if (TASK_REVIVAL_INTERVAL > 0) { + new Thread("ClusterScheduler task offer revival check") { + setDaemon(true) + + override def run() { + logInfo("Starting speculative task offer revival thread") + while (true) { + try { + Thread.sleep(TASK_REVIVAL_INTERVAL) + } catch { + case e: InterruptedException => {} + } + + if (hasPendingTasks()) backend.reviveOffers() + } + } + }.start() + } } override def submitTasks(taskSet: TaskSet) { @@ -139,22 +190,92 @@ private[spark] class ClusterScheduler(val sc: SparkContext) SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { - executorIdToHost(o.executorId) = o.hostname - if (!executorsByHost.contains(o.hostname)) { - executorsByHost(o.hostname) = new HashSet() + // DEBUG Code + Utils.checkHostPort(o.hostPort) + + executorIdToHostPort(o.executorId) = o.hostPort + if (! executorsByHostPort.contains(o.hostPort)) { + executorsByHostPort(o.hostPort) = new HashSet[String]() } + + hostPortsAlive += o.hostPort + hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort) + executorGained(o.executorId, o.hostPort) } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = offers.map(o => o.cores).toArray var launchedTask = false + + for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { + + // Split offers based on host local, rack local and off-rack tasks. + val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val otherOffers = new HashMap[String, ArrayBuffer[Int]]() + + for (i <- 0 until offers.size) { + val hostPort = offers(i).hostPort + // DEBUG code + Utils.checkHostPort(hostPort) + val host = Utils.parseHostPort(hostPort)._1 + val numHostLocalTasks = math.max(0, math.min(manager.numPendingTasksForHost(hostPort), availableCpus(i))) + if (numHostLocalTasks > 0){ + val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numHostLocalTasks) list += i + } + + val numRackLocalTasks = math.max(0, + // Remove host local tasks (which are also rack local btw !) from this + math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHostLocalTasks, availableCpus(i))) + if (numRackLocalTasks > 0){ + val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numRackLocalTasks) list += i + } + if (numHostLocalTasks <= 0 && numRackLocalTasks <= 0){ + // add to others list - spread even this across cluster. + val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + list += i + } + } + + val offersPriorityList = new ArrayBuffer[Int]( + hostLocalOffers.size + rackLocalOffers.size + otherOffers.size) + // First host local, then rack, then others + val numHostLocalOffers = { + val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers) + offersPriorityList ++= hostLocalPriorityList + hostLocalPriorityList.size + } + val numRackLocalOffers = { + val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers) + offersPriorityList ++= rackLocalPriorityList + rackLocalPriorityList.size + } + offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers) + + var lastLoop = false + val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match { + case TaskLocality.HOST_LOCAL => numHostLocalOffers + case TaskLocality.RACK_LOCAL => numRackLocalOffers + numHostLocalOffers + case TaskLocality.ANY => offersPriorityList.size + } + do { launchedTask = false - for (i <- 0 until offers.size) { + var loopCount = 0 + for (i <- offersPriorityList) { val execId = offers(i).executorId - val host = offers(i).hostname - manager.slaveOffer(execId, host, availableCpus(i)) match { + val hostPort = offers(i).hostPort + + // If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing) + val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null + + // If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ... + loopCount += 1 + + manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match { case Some(task) => tasks(i) += task val tid = task.taskId @@ -162,15 +283,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext) taskSetTaskIds(manager.taskSet.id) += tid taskIdToExecutorId(tid) = execId activeExecutorIds += execId - executorsByHost(host) += execId + executorsByHostPort(hostPort) += execId availableCpus(i) -= 1 launchedTask = true - + case None => {} } } + // Loop once more - when lastLoop = true, then we try to schedule task on all nodes irrespective of + // data locality (we still go in order of priority : but that would not change anything since + // if data local tasks had been available, we would have scheduled them already) + if (lastLoop) { + // prevent more looping + launchedTask = false + } else if (!lastLoop && !launchedTask) { + // Do this only if TASK_SCHEDULING_AGGRESSION != HOST_LOCAL + if (TASK_SCHEDULING_AGGRESSION != TaskLocality.HOST_LOCAL) { + // fudge launchedTask to ensure we loop once more + launchedTask = true + // dont loop anymore + lastLoop = true + } + } } while (launchedTask) } + if (tasks.size > 0) { hasLaunchedTask = true } @@ -256,10 +393,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (jarServer != null) { jarServer.stop() } + + // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. + // TODO: Do something better ! + Thread.sleep(5000L) } override def defaultParallelism() = backend.defaultParallelism() + // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { var shouldRevive = false @@ -273,12 +415,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } + // Check for pending tasks in all our active jobs. + def hasPendingTasks(): Boolean = { + synchronized { + activeTaskSetsQueue.exists( _.hasPendingTasks() ) + } + } + def executorLost(executorId: String, reason: ExecutorLossReason) { var failedExecutor: Option[String] = None + synchronized { if (activeExecutorIds.contains(executorId)) { - val host = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, host, reason)) + val hostPort = executorIdToHostPort(executorId) + logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) removeExecutor(executorId) failedExecutor = Some(executorId) } else { @@ -296,19 +446,95 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - /** Get a list of hosts that currently have executors */ - def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet - /** Remove an executor from all our data structures and mark it as lost */ private def removeExecutor(executorId: String) { activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) + val hostPort = executorIdToHostPort(executorId) + if (hostPortsAlive.contains(hostPort)) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + hostPortsAlive -= hostPort + hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort) + } + + val execs = executorsByHostPort.getOrElse(hostPort, new HashSet) execs -= executorId if (execs.isEmpty) { - executorsByHost -= host + executorsByHostPort -= hostPort } - executorIdToHost -= executorId - activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + executorIdToHostPort -= executorId + activeTaskSetsQueue.foreach(_.executorLost(executorId, hostPort)) + } + + def executorGained(execId: String, hostPort: String) { + listener.executorGained(execId, hostPort) + } + + def getExecutorsAliveOnHost(host: String): Option[Set[String]] = { + val retval = hostToAliveHostPorts.get(host) + if (retval.isDefined) { + return Some(retval.get.toSet) + } + + None + } + + // By default, rack is unknown + def getRackForHost(value: String): Option[String] = None + + // By default, (cached) hosts for rack is unknown + def getCachedHostsForRack(rack: String): Option[Set[String]] = None +} + + +object ClusterScheduler { + + // Used to 'spray' available containers across the available set to ensure too many containers on same host + // are not used up. Used in yarn mode and in task scheduling (when there are multiple containers available + // to execute a task) + // For example: yarn can returns more containers than we would have requested under ANY, this method + // prioritizes how to use the allocated containers. + // flatten the map such that the array buffer entries are spread out across the returned value. + // given <host, list[container]> == <h1, [c1 .. c5]>, <h2, [c1 .. c3]>, <h3, [c1, c2]>, <h4, c1>, <h5, c1>, i + // the return value would be something like : h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5 + // We then 'use' the containers in this order (consuming only the top K from this list where + // K = number to be user). This is to ensure that if we have multiple eligible allocations, + // they dont end up allocating all containers on a small number of hosts - increasing probability of + // multiple container failure when a host goes down. + // Note, there is bias for keys with higher number of entries in value to be picked first (by design) + // Also note that invocation of this method is expected to have containers of same 'type' + // (host-local, rack-local, off-rack) and not across types : so that reordering is simply better from + // the available list - everything else being same. + // That is, we we first consume data local, then rack local and finally off rack nodes. So the + // prioritization from this method applies to within each category + def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { + val _keyList = new ArrayBuffer[K](map.size) + _keyList ++= map.keys + + // order keyList based on population of value in map + val keyList = _keyList.sortWith( + (left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size + ) + + val retval = new ArrayBuffer[T](keyList.size * 2) + var index = 0 + var found = true + + while (found){ + found = false + for (key <- keyList) { + val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) + assert(containerList != null) + // Get the index'th entry for this host - if present + if (index < containerList.size){ + retval += containerList.apply(index) + found = true + } + } + index += 1 + } + + retval.toList } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bb289c9cf3..6b61152ed0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -27,7 +27,7 @@ private[spark] class SparkDeploySchedulerBackend( val driverUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTPORT}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse( throw new IllegalArgumentException("must supply spark home for spark standalone")) @@ -57,9 +57,9 @@ private[spark] class SparkDeploySchedulerBackend( } } - override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( - executorId, host, cores, Utils.memoryMegabytesToString(memory))) + override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { + logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + executorId, hostPort, cores, Utils.memoryMegabytesToString(memory))) } override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index d766067824..3335294844 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster import spark.TaskState.TaskState import java.nio.ByteBuffer import spark.util.SerializableBuffer +import spark.Utils private[spark] sealed trait StandaloneClusterMessage extends Serializable @@ -19,8 +20,10 @@ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMess // Executors to driver private[spark] -case class RegisterExecutor(executorId: String, host: String, cores: Int) - extends StandaloneClusterMessage +case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + extends StandaloneClusterMessage { + Utils.checkHostPort(hostPort, "Expected host port") +} private[spark] case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 7a428e3361..c20276a605 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -5,8 +5,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import akka.actor._ import akka.util.duration._ import akka.pattern.ask +import akka.util.Duration -import spark.{SparkException, Logging, TaskState} +import spark.{Utils, SparkException, Logging, TaskState} import akka.dispatch.Await import java.util.concurrent.atomic.AtomicInteger import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} @@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor var totalCoreCount = new AtomicInteger(0) class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { - val executorActor = new HashMap[String, ActorRef] - val executorAddress = new HashMap[String, Address] - val executorHost = new HashMap[String, String] - val freeCores = new HashMap[String, Int] - val actorToExecutorId = new HashMap[ActorRef, String] - val addressToExecutorId = new HashMap[Address, String] + private val executorActor = new HashMap[String, ActorRef] + private val executorAddress = new HashMap[String, Address] + private val executorHostPort = new HashMap[String, String] + private val freeCores = new HashMap[String, Int] + private val actorToExecutorId = new HashMap[ActorRef, String] + private val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterExecutor(executorId, host, cores) => + case RegisterExecutor(executorId, hostPort, cores) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorActor.contains(executorId)) { sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { @@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! RegisteredExecutor(sparkProperties) context.watch(sender) executorActor(executorId) = sender - executorHost(executorId) = host + executorHostPort(executorId) = hostPort freeCores(executorId) = cores executorAddress(executorId) = sender.path.address actorToExecutorId(sender) = executorId @@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))})) } // Make fake resource offers on just one executor def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers @@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor actorToExecutorId -= executorActor(executorId) addressToExecutorId -= executorAddress(executorId) executorActor -= executorId - executorHost -= executorId + executorHostPort -= executorId freeCores -= executorId - executorHost -= executorId + executorHostPort -= executorId totalCoreCount.addAndGet(-numCores) scheduler.executorLost(executorId, SlaveLost(reason)) } @@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor while (iterator.hasNext) { val entry = iterator.next val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { + if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { properties += ((key, value)) } } @@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) } + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + override def stop() { try { if (driverActor != null) { - val timeout = 5.seconds val future = driverActor.ask(StopDriver)(timeout) Await.result(future, timeout) } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index dfe3c5a85b..718f26bfbd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -1,5 +1,7 @@ package spark.scheduler.cluster +import spark.Utils + /** * Information about a running task attempt inside a TaskSet. */ @@ -9,8 +11,11 @@ class TaskInfo( val index: Int, val launchTime: Long, val executorId: String, - val host: String, - val preferred: Boolean) { + val hostPort: String, + val taskLocality: TaskLocality.TaskLocality) { + + Utils.checkHostPort(hostPort, "Expected hostport") + var finishTime: Long = 0 var failed = false diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index c9f2c48804..27e713e2c4 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -1,7 +1,6 @@ package spark.scheduler.cluster -import java.util.Arrays -import java.util.{HashMap => JHashMap} +import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -14,6 +13,36 @@ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer +private[spark] object TaskLocality extends Enumeration("HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging { + + val HOST_LOCAL, RACK_LOCAL, ANY = Value + + type TaskLocality = Value + + def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + + constraint match { + case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL + case TaskLocality.RACK_LOCAL => condition == TaskLocality.HOST_LOCAL || condition == TaskLocality.RACK_LOCAL + // For anything else, allow + case _ => true + } + } + + def parse(str: String): TaskLocality = { + // better way to do this ? + try { + TaskLocality.withName(str) + } catch { + case nEx: NoSuchElementException => { + logWarning("Invalid task locality specified '" + str + "', defaulting to HOST_LOCAL"); + // default to preserve earlier behavior + HOST_LOCAL + } + } + } +} + /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ @@ -47,14 +76,22 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Last time when we launched a preferred task (for delay scheduling) var lastPreferredLaunchTime = System.currentTimeMillis - // List of pending tasks for each node. These collections are actually + // List of pending tasks for each node (hyper local to container). These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put // back at the head of the stack. They are also only cleaned up lazily; // when a task is launched, it remains in all the pending lists except // the one that it was launched from, but gets removed from them later. - val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node. + // Essentially, similar to pendingTasksForHostPort, except at host level + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node based on rack locality. + // Essentially, similar to pendingTasksForHost, except at rack level + private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] // List containing pending tasks with no locality preferences val pendingTasksWithNoPrefs = new ArrayBuffer[Int] @@ -96,26 +133,117 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe addPendingTask(i) } + private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, rackLocal: Boolean = false): ArrayBuffer[String] = { + // DEBUG code + _taskPreferredLocations.foreach(h => Utils.checkHost(h, "taskPreferredLocation " + _taskPreferredLocations)) + + val taskPreferredLocations = if (! rackLocal) _taskPreferredLocations else { + // Expand set to include all 'seen' rack local hosts. + // This works since container allocation/management happens within master - so any rack locality information is updated in msater. + // Best case effort, and maybe sort of kludge for now ... rework it later ? + val hosts = new HashSet[String] + _taskPreferredLocations.foreach(h => { + val rackOpt = scheduler.getRackForHost(h) + if (rackOpt.isDefined) { + val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) + if (hostsOpt.isDefined) { + hosts ++= hostsOpt.get + } + } + + // Ensure that irrespective of what scheduler says, host is always added ! + hosts += h + }) + + hosts + } + + val retval = new ArrayBuffer[String] + scheduler.synchronized { + for (prefLocation <- taskPreferredLocations) { + val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(prefLocation) + if (aliveLocationsOpt.isDefined) { + retval ++= aliveLocationsOpt.get + } + } + } + + retval + } + // Add a task to all the pending-task lists that it should be on. private def addPendingTask(index: Int) { - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (locations.size == 0) { + // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate + // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. + val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched) + val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, true) + + if (rackLocalLocations.size == 0) { + // Current impl ensures this. + assert (hostLocalLocations.size == 0) pendingTasksWithNoPrefs += index } else { - for (host <- locations) { - val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + + // host locality + for (hostPort <- hostLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) + hostPortList += index + + val host = Utils.parseHostPort(hostPort)._1 + val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + hostList += index + } + + // rack locality + for (rackLocalHostPort <- rackLocalLocations) { + // DEBUG Code + Utils.checkHostPort(rackLocalHostPort) + + val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 + val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) list += index } } + allPendingTasks += index } + // Return the pending tasks list for a given host port (hyper local), or an empty list if + // there is no map entry for that host + private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { + // DEBUG Code + Utils.checkHostPort(hostPort) + pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) + } + // Return the pending tasks list for a given host, or an empty list if // there is no map entry for that host - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 pendingTasksForHost.getOrElse(host, ArrayBuffer()) } + // Return the pending tasks (rack level) list for a given host, or an empty list if + // there is no map entry for that host + private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 + pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Number of pending tasks for a given host (which would be data local) + def numPendingTasksForHost(hostPort: String): Int = { + getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending rack local tasks for a given host + def numRackLocalPendingTasksForHost(hostPort: String): Int = { + getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Dequeue a pending task from the given list and return its index. // Return None if the list is empty. // This method also cleans up any tasks in the list that have already @@ -132,26 +260,49 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } // Return a speculative task for a given host if any are available. The task should not have an - // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the - // task must have a preference for this host (or no preferred locations at all). - private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { - val hostsAlive = sched.hostsAlive + // attempt running on this host, in case the host is slow. In addition, if locality is set, the + // task must have a preference for this host/rack/no preferred locations at all. + private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + + assert (TaskLocality.isAllowed(locality, TaskLocality.HOST_LOCAL)) speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - val localTask = speculatableTasks.find { - index => - val locations = tasks(index).preferredLocations.toSet & hostsAlive - val attemptLocs = taskAttempts(index).map(_.host) - (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) + + if (speculatableTasks.size > 0) { + val localTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched) + val attemptLocs = taskAttempts(index).map(_.hostPort) + (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) + } + + if (localTask != None) { + speculatableTasks -= localTask.get + return localTask } - if (localTask != None) { - speculatableTasks -= localTask.get - return localTask - } - if (!localOnly && speculatableTasks.size > 0) { - val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host)) - if (nonLocalTask != None) { - speculatableTasks -= nonLocalTask.get - return nonLocalTask + + // check for rack locality + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, true) + val attemptLocs = taskAttempts(index).map(_.hostPort) + locations.contains(hostPort) && !attemptLocs.contains(hostPort) + } + + if (rackTask != None) { + speculatableTasks -= rackTask.get + return rackTask + } + } + + // Any task ... + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + // Check for attemptLocs also ? + val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort)) + if (nonLocalTask != None) { + speculatableTasks -= nonLocalTask.get + return nonLocalTask + } } } return None @@ -159,59 +310,103 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. - private def findTask(host: String, localOnly: Boolean): Option[Int] = { - val localTask = findTaskFromList(getPendingTasksForHost(host)) + private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) if (localTask != None) { return localTask } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort)) + if (rackLocalTask != None) { + return rackLocalTask + } + } + + // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner. + // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down). val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) if (noPrefTask != None) { return noPrefTask } - if (!localOnly) { + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { val nonLocalTask = findTaskFromList(allPendingTasks) if (nonLocalTask != None) { return nonLocalTask } } + // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(host, localOnly) + return findSpeculativeTask(hostPort, locality) } // Does a host count as a preferred location for a task? This is true if // either the task has preferred locations and this host is one, or it has // no preferred locations (in which we still count the launch as preferred). - private def isPreferredLocation(task: Task[_], host: String): Boolean = { + private def isPreferredLocation(task: Task[_], hostPort: String): Boolean = { val locs = task.preferredLocations - return (locs.contains(host) || locs.isEmpty) + // DEBUG code + locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs)) + + if (locs.contains(hostPort) || locs.isEmpty) return true + + val host = Utils.parseHostPort(hostPort)._1 + locs.contains(host) + } + + // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). + // This is true if either the task has preferred locations and this host is one, or it has + // no preferred locations (in which we still count the launch as preferred). + def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { + + val locs = task.preferredLocations + + // DEBUG code + locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs)) + + val preferredRacks = new HashSet[String]() + for (preferredHost <- locs) { + val rack = sched.getRackForHost(preferredHost) + if (None != rack) preferredRacks += rack.get + } + + if (preferredRacks.isEmpty) return false + + val hostRack = sched.getRackForHost(hostPort) + + return None != hostRack && preferredRacks.contains(hostRack.get) } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - val time = System.currentTimeMillis - val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) + // If explicitly specified, use that + val locality = if (overrideLocality != null) overrideLocality else { + // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... + val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.HOST_LOCAL else TaskLocality.ANY + } - findTask(host, localOnly) match { + findTask(hostPort, locality) match { case Some(index) => { // Found a task; do some bookkeeping and return a Mesos task for it val task = tasks(index) val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch - val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) { - "preferred" - } else { - "non-preferred, not one of " + task.preferredLocations.mkString(", ") - } - logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, prefStr)) + val taskLocality = if (isPreferredLocation(task, hostPort)) TaskLocality.HOST_LOCAL else + if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY + val prefStr = taskLocality.toString + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId, execId, hostPort, prefStr)) // Do various bookkeeping copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, time, execId, host, preferred) + val time = System.currentTimeMillis + val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) - if (preferred) { + if (TaskLocality.HOST_LOCAL == taskLocality) { lastPreferredLaunchTime = time } // Serialize and return the task @@ -355,17 +550,15 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe sched.taskSetFinished(this) } - def executorLost(execId: String, hostname: String) { + def executorLost(execId: String, hostPort: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - val newHostsAlive = sched.hostsAlive // If some task has preferred locations only on hostname, and there are no more executors there, // put it in the no-prefs list to avoid the wait from delay scheduling - if (!newHostsAlive.contains(hostname)) { - for (index <- getPendingTasksForHost(hostname)) { - val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index - } + for (index <- getPendingTasksForHostPort(hostPort)) { + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, true) + if (newLocs.isEmpty) { + assert (findPreferredLocations(tasks(index).preferredLocations, sched).isEmpty) + pendingTasksWithNoPrefs += index } } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage @@ -419,7 +612,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe !speculatableTasks.contains(index)) { logInfo( "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) + taskSet.id, index, info.hostPort, threshold)) speculatableTasks += index foundTasks = true } @@ -427,4 +620,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } return foundTasks } + + def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksFinished < numTasks + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala index 3c3afcbb14..c47824315c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala @@ -4,5 +4,5 @@ package spark.scheduler.cluster * Represents free resources available on an executor. */ private[spark] -class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) { +class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) { } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 9e1bde3fbe..f060a940a9 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.HashMap import spark._ import spark.executor.ExecutorURLClassLoader import spark.scheduler._ -import spark.scheduler.cluster.TaskInfo +import spark.scheduler.cluster.{TaskLocality, TaskInfo} /** * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally @@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon def runTask(task: Task[_], idInJob: Int, attemptId: Int) { logInfo("Running " + task) - val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true) + val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.HOST_LOCAL) // Set the Spark execution environment for the worker thread SparkEnv.set(env) try { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 210061e972..10e70723db 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -37,17 +37,27 @@ class BlockManager( maxMemory: Long) extends Logging { - class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - var pending: Boolean = true - var size: Long = -1L - var failed: Boolean = false + private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + @volatile var pending: Boolean = true + @volatile var size: Long = -1L + @volatile var initThread: Thread = null + @volatile var failed = false + + setInitThread() + + private def setInitThread() { + // Set current thread as init thread - waitForReady will not block this thread + // (in case there is non trivial initialization which ends up calling waitForReady as part of + // initialization itself) + this.initThread = Thread.currentThread() + } /** * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). * Return true if the block is available, false otherwise. */ def waitForReady(): Boolean = { - if (pending) { + if (initThread != Thread.currentThread() && pending) { synchronized { while (pending) this.wait() } @@ -57,19 +67,26 @@ class BlockManager( /** Mark this BlockInfo as ready (i.e. block is finished writing) */ def markReady(sizeInBytes: Long) { + assert (pending) + size = sizeInBytes + initThread = null + failed = false + initThread = null + pending = false synchronized { - pending = false - failed = false - size = sizeInBytes this.notifyAll() } } /** Mark this BlockInfo as ready but failed */ def markFailure() { + assert (pending) + size = 0 + initThread = null + failed = true + initThread = null + pending = false synchronized { - failed = true - pending = false this.notifyAll() } } @@ -101,7 +118,7 @@ class BlockManager( val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties - val host = System.getProperty("spark.hostname", Utils.localHostName()) + val hostPort = Utils.localHostPort() val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @@ -212,9 +229,12 @@ class BlockManager( * Tell the master about the current storage status of a block. This will send a block update * message reflecting the current status, *not* the desired storage level in its block info. * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. + * + * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). + * This ensures that update in master will compensate for the increase in memory on slave. */ - def reportBlockStatus(blockId: String, info: BlockInfo) { - val needReregister = !tryToReportBlockStatus(blockId, info) + def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { + val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. @@ -228,7 +248,7 @@ class BlockManager( * which will be true if the block was successfully recorded and false if * the slave needs to re-register. */ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = { + private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { info.level match { case null => @@ -237,7 +257,7 @@ class BlockManager( val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) - val memSize = if (inMem) memoryStore.getSize(blockId) else 0L + val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L (storageLevel, memSize, diskSize, info.tellMaster) } @@ -257,7 +277,7 @@ class BlockManager( def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis var managers = master.getLocations(blockId) - val locations = managers.map(_.ip) + val locations = managers.map(_.hostPort) logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -267,7 +287,7 @@ class BlockManager( */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.hostPort).toSeq).toArray logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -339,6 +359,8 @@ class BlockManager( case Some(bytes) => // Put a copy of the block back in memory before returning it. Note that we can't // put the ByteBuffer returned by the disk store as that's a memory-mapped file. + // The use of rewind assumes this. + assert (0 == bytes.position()) val copyForMemory = ByteBuffer.allocate(bytes.limit) copyForMemory.put(bytes) memoryStore.putBytes(blockId, copyForMemory, level) @@ -411,6 +433,7 @@ class BlockManager( // Read it as a byte buffer into memory first, then return it diskStore.getBytes(blockId) match { case Some(bytes) => + assert (0 == bytes.position()) if (level.useMemory) { if (level.deserialized) { memoryStore.putBytes(blockId, bytes, level) @@ -450,7 +473,7 @@ class BlockManager( for (loc <- locations) { logDebug("Getting remote block " + blockId + " from " + loc) val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port)) + GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { return Some(dataDeserialize(blockId, data)) } @@ -501,17 +524,17 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId).orNull - if (oldBlock != null && oldBlock.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return oldBlock.size - } - // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. val myInfo = new BlockInfo(level, tellMaster) - blockInfo.put(blockId, myInfo) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, myInfo) + + if (oldBlockOpt.isDefined && oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return oldBlockOpt.get.size + } val startTimeMs = System.currentTimeMillis @@ -531,6 +554,7 @@ class BlockManager( logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") + var marked = false try { if (level.useMemory) { // Save it just to memory first, even if it also has useDisk set to true; we will later @@ -555,20 +579,20 @@ class BlockManager( // Now that the block is in either the memory or disk store, let other threads read it, // and tell the master about it. + marked = true myInfo.markReady(size) if (tellMaster) { reportBlockStatus(blockId, myInfo) } - } catch { + } finally { // If we failed at putting the block to memory/disk, notify other possible readers // that it has failed, and then remove it from the block info map. - case e: Exception => { + if (! marked) { // Note that the remove must happen before markFailure otherwise another thread // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) myInfo.markFailure() - logWarning("Putting block " + blockId + " failed", e) - throw e + logWarning("Putting block " + blockId + " failed") } } } @@ -611,16 +635,17 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.contains(blockId)) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. val myInfo = new BlockInfo(level, tellMaster) - blockInfo.put(blockId, myInfo) + // Do atomically ! + val prevInfo = blockInfo.putIfAbsent(blockId, myInfo) + if (prevInfo != null) { + // Should we check for prevInfo.waitForReady() here ? + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return + } val startTimeMs = System.currentTimeMillis @@ -639,6 +664,7 @@ class BlockManager( logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") + var marked = false try { if (level.useMemory) { // Store it only in memory at first, even if useDisk is also set to true @@ -649,22 +675,24 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } + // assert (0 == bytes.position(), "" + bytes) + // Now that the block is in either the memory or disk store, let other threads read it, // and tell the master about it. + marked = true myInfo.markReady(bytes.limit) if (tellMaster) { reportBlockStatus(blockId, myInfo) } - } catch { + } finally { // If we failed at putting the block to memory/disk, notify other possible readers // that it has failed, and then remove it from the block info map. - case e: Exception => { + if (! marked) { // Note that the remove must happen before markFailure otherwise another thread // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) myInfo.markFailure() - logWarning("Putting block " + blockId + " failed", e) - throw e + logWarning("Putting block " + blockId + " failed") } } } @@ -698,7 +726,7 @@ class BlockManager( logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " + data.limit() + " Bytes. To node: " + peer) if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.ip, peer.port))) { + new ConnectionManagerId(peer.host, peer.port))) { logError("Failed to call syncPutBlock to " + peer) } logDebug("Replicated BlockId " + blockId + " once used " + @@ -730,6 +758,14 @@ class BlockManager( val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { + // required ? As of now, this will be invoked only for blocks which are ready + // But in case this changes in future, adding for consistency sake. + if (! info.waitForReady() ) { + // If we get here, the block write failed. + logWarning("Block " + blockId + " was marked as failure. Nothing to drop") + return + } + val level = info.level if (level.useDisk && !diskStore.contains(blockId)) { logInfo("Writing block " + blockId + " to disk") @@ -740,12 +776,13 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } } + val droppedMemorySize = memoryStore.getSize(blockId) val blockWasRemoved = memoryStore.remove(blockId) if (!blockWasRemoved) { logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") } if (info.tellMaster) { - reportBlockStatus(blockId, info) + reportBlockStatus(blockId, info, droppedMemorySize) } if (!level.useDisk) { // The block is completely gone from this node; forget it so we can put() it again later. @@ -938,8 +975,8 @@ class BlockFetcherIterator( def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) - val cmId = new ConnectionManagerId(req.address.ip, req.address.port) + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) val blockMessageArray = new BlockMessageArray(req.blocks.map { case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) }) diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index f2f1e77d41..f4a2181490 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -2,6 +2,7 @@ package spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +import spark.Utils /** * This class represent an unique identifier for a BlockManager. @@ -13,7 +14,7 @@ import java.util.concurrent.ConcurrentHashMap */ private[spark] class BlockManagerId private ( private var executorId_ : String, - private var ip_ : String, + private var host_ : String, private var port_ : Int ) extends Externalizable { @@ -21,32 +22,45 @@ private[spark] class BlockManagerId private ( def executorId: String = executorId_ - def ip: String = ip_ + if (null != host_){ + Utils.checkHost(host_, "Expected hostname") + assert (port_ > 0) + } + + def hostPort: String = { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + host + ":" + port + } + + def host: String = host_ def port: Int = port_ override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) - out.writeUTF(ip_) + out.writeUTF(host_) out.writeInt(port_) } override def readExternal(in: ObjectInput) { executorId_ = in.readUTF() - ip_ = in.readUTF() + host_ = in.readUTF() port_ = in.readInt() } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port) + override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port) - override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port + override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port override def equals(that: Any) = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && ip == id.ip + executorId == id.executorId && port == id.port && host == id.host case _ => false } @@ -55,8 +69,8 @@ private[spark] class BlockManagerId private ( private[spark] object BlockManagerId { - def apply(execId: String, ip: String, port: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, ip, port)) + def apply(execId: String, host: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() @@ -67,11 +81,7 @@ private[spark] object BlockManagerId { val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - if (blockManagerIdCache.containsKey(id)) { - blockManagerIdCache.get(id) - } else { - blockManagerIdCache.put(id, id) - id - } + blockManagerIdCache.putIfAbsent(id, id) + blockManagerIdCache.get(id) } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 2830bc6297..3ce1e6e257 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -332,8 +332,8 @@ object BlockManagerMasterActor { // Mapping from block id to its status. private val _blocks = new JHashMap[String, BlockStatus] - logInfo("Registering block manager %s:%d with %s RAM".format( - blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) + logInfo("Registering block manager %s with %s RAM".format( + blockManagerId.hostPort, Utils.memoryBytesToString(maxMem))) def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() @@ -358,13 +358,13 @@ object BlockManagerMasterActor { _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize - logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + logInfo("Added %s in memory on %s (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s:%d (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + logInfo("Added %s on disk on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize))) } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. @@ -372,13 +372,13 @@ object BlockManagerMasterActor { _blocks.remove(blockId) if (blockStatus.storageLevel.useMemory) { _remainingMem += blockStatus.memSize - logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s:%d on disk (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + logInfo("Removed %s on %s on disk (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize))) } } } diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala index a25decb123..ee0c5ff9a2 100644 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala @@ -115,6 +115,7 @@ private[spark] object BlockMessageArray { val newBuffer = ByteBuffer.allocate(totalSize) newBuffer.clear() bufferMessage.buffers.foreach(buffer => { + assert (0 == buffer.position()) newBuffer.put(buffer) buffer.rewind() }) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index ddbf8821ad..c9553d2e0f 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -20,6 +20,9 @@ import spark.Utils private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) { + private val mapMode = MapMode.READ_ONLY + private var mapOpenMode = "r" + val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt @@ -35,7 +38,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getFile(blockId).length() } - override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // So that we do not modify the input offsets ! + // duplicate does not copy buffer, so inexpensive + val bytes = _bytes.duplicate() logDebug("Attempting to put block " + blockId) val startTime = System.currentTimeMillis val file = createFile(blockId) @@ -49,6 +55,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime))) } + private def getFileBytes(file: File): ByteBuffer = { + val length = file.length() + val channel = new RandomAccessFile(file, mapOpenMode).getChannel() + val buffer = try { + channel.map(mapMode, 0, length) + } finally { + channel.close() + } + + buffer + } + override def putValues( blockId: String, values: ArrayBuffer[Any], @@ -70,9 +88,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) if (returnValues) { // Return a byte buffer for the contents of the file - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = channel.map(MapMode.READ_ONLY, 0, length) - channel.close() + val buffer = getFileBytes(file) PutResult(length, Right(buffer)) } else { PutResult(length, null) @@ -81,10 +97,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def getBytes(blockId: String): Option[ByteBuffer] = { val file = getFile(blockId) - val length = file.length().toInt - val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = channel.map(MapMode.READ_ONLY, 0, length) - channel.close() + val bytes = getFileBytes(file) Some(bytes) } @@ -96,7 +109,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) if (file.exists()) { file.delete() - true } else { false } @@ -175,11 +187,12 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def addShutdownHook() { + localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir) ) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") try { - localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + localDirs.foreach(localDir => if (! Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)) } catch { case t: Throwable => logError("Exception while deleting local spark dirs", t) } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 949588476c..eba5ee507f 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // Work on a duplicate - since the original input might be used elsewhere. + val bytes = _bytes.duplicate() bytes.rewind() if (level.deserialized) { val values = blockManager.dataDeserialize(blockId, bytes) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 3b5a77ab22..cc0c354e7e 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -123,11 +123,7 @@ object StorageLevel { val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { - if (storageLevelCache.containsKey(level)) { - storageLevelCache.get(level) - } else { - storageLevelCache.put(level, level) - level - } + storageLevelCache.putIfAbsent(level, level) + storageLevelCache.get(level) } } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 3e805b7831..9fb7e001ba 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -11,7 +11,7 @@ import cc.spray.{SprayCanRootService, HttpService} import cc.spray.can.server.HttpServer import cc.spray.io.pipelines.MessageHandlerDispatch.SingletonHandler import akka.dispatch.Await -import spark.SparkException +import spark.{Utils, SparkException} import java.util.concurrent.TimeoutException /** @@ -31,7 +31,10 @@ private[spark] object AkkaUtils { val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean + val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. + val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt + val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] @@ -45,8 +48,9 @@ private[spark] object AkkaUtils { akka.remote.netty.execution-pool-size = %d akka.actor.default-dispatcher.throughput = %d akka.remote.log-remote-lifecycle-events = %s + akka.remote.netty.write-timeout = %ds """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, - if (lifecycleEvents) "on" else "off")) + lifecycleEvents, akkaWriteTimeout)) val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader) @@ -60,8 +64,9 @@ private[spark] object AkkaUtils { /** * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to * handle requests. Returns the bound port or throws a SparkException on failure. + * TODO: Not changing ip to host here - is it required ? */ - def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, + def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, name: String = "HttpServer"): ActorRef = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 188f8910da..92dfaa6e6f 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -3,6 +3,7 @@ package spark.util import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConversions import scala.collection.mutable.Map +import spark.scheduler.MapStatus /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion @@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { this } + // Should we return previous value directly or as Option ? + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, (value, currentTime)) + if (prev != null) Some(prev._1) else None + } + + override def -= (key: A): this.type = { internalMap.remove(key) this diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html index ac51a39a51..b9b9f08810 100644 --- a/core/src/main/twirl/spark/deploy/master/index.scala.html +++ b/core/src/main/twirl/spark/deploy/master/index.scala.html @@ -2,7 +2,7 @@ @import spark.deploy.master._ @import spark.Utils -@spark.common.html.layout(title = "Spark Master on " + state.host) { +@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) { <!-- Cluster Details --> <div class="row"> diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html index c39f769a73..0e66af9284 100644 --- a/core/src/main/twirl/spark/deploy/worker/index.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html @@ -1,7 +1,7 @@ @(worker: spark.deploy.WorkerState) @import spark.Utils -@spark.common.html.layout(title = "Spark Worker on " + worker.host) { +@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) { <!-- Worker Details --> <div class="row"> diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html index d54b8de4cc..cd72a688c1 100644 --- a/core/src/main/twirl/spark/storage/worker_table.scala.html +++ b/core/src/main/twirl/spark/storage/worker_table.scala.html @@ -12,7 +12,7 @@ <tbody> @for(status <- workersStatusList) { <tr> - <td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td> + <td>@(status.blockManagerId.host + ":" + status.blockManagerId.port)</td> <td> @(Utils.memoryBytesToString(status.memUsed(prefix))) (@(Utils.memoryBytesToString(status.memRemaining)) Total Available) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 4104b33c8b..c9b4707def 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -153,7 +153,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val blockManager = SparkEnv.get.blockManager blockManager.master.getLocations(blockId).foreach(id => { val bytes = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(id.ip, id.port)) + GetBlock(blockId), ConnectionManagerId(id.host, id.port)) val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) }) diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 91b48c7456..a3840905f4 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -18,6 +18,7 @@ class FileSuite extends FunSuite with LocalSparkContext { val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 4) nums.saveAsTextFile(outputDir) + println("outputDir = " + outputDir) // Read the plain text file and check it's OK val outputFile = new File(outputDir, "part-00000") val content = Source.fromFile(outputFile).mkString diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 6da58a0f6e..c0f8986de8 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB")) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) } |