From 5e91495f5c718c837b5a5af2268f6faad00d357f Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 30 Oct 2013 17:07:24 -0700 Subject: Deduplicate Local and Cluster schedulers. The code in LocalScheduler/LocalTaskSetManager was nearly identical to the code in ClusterScheduler/ClusterTaskSetManager. The redundancy made making updating the schedulers unnecessarily painful and error- prone. This commit combines the two into a single TaskScheduler/ TaskSetManager. --- .../main/scala/org/apache/spark/SparkContext.scala | 34 +- .../spark/scheduler/ExecutorLossReason.scala | 38 ++ .../apache/spark/scheduler/SchedulerBackend.scala | 37 ++ .../apache/spark/scheduler/TaskResultGetter.scala | 107 ++++ .../org/apache/spark/scheduler/TaskScheduler.scala | 480 +++++++++++++- .../apache/spark/scheduler/TaskSetManager.scala | 688 +++++++++++++++++++- .../org/apache/spark/scheduler/WorkerOffer.scala | 24 + .../spark/scheduler/cluster/ClusterScheduler.scala | 486 -------------- .../scheduler/cluster/ClusterTaskSetManager.scala | 703 --------------------- .../cluster/CoarseGrainedSchedulerBackend.scala | 5 +- .../scheduler/cluster/ExecutorLossReason.scala | 38 -- .../spark/scheduler/cluster/SchedulerBackend.scala | 37 -- .../scheduler/cluster/SimrSchedulerBackend.scala | 4 +- .../cluster/SparkDeploySchedulerBackend.scala | 6 +- .../spark/scheduler/cluster/TaskResultGetter.scala | 108 ---- .../spark/scheduler/cluster/WorkerOffer.scala | 24 - .../mesos/CoarseMesosSchedulerBackend.scala | 5 +- .../cluster/mesos/MesosSchedulerBackend.scala | 9 +- .../spark/scheduler/local/LocalBackend.scala | 73 +++ .../spark/scheduler/local/LocalScheduler.scala | 219 ------- .../scheduler/local/LocalTaskSetManager.scala | 191 ------ .../test/scala/org/apache/spark/FailureSuite.scala | 20 +- .../spark/scheduler/SparkListenerSuite.scala | 19 +- .../scheduler/cluster/TaskResultGetterSuite.scala | 4 +- .../scheduler/local/LocalSchedulerSuite.scala | 227 ------- .../scheduler/cluster/YarnClusterScheduler.scala | 10 +- 26 files changed, 1472 insertions(+), 2124 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ade75e20d5..1850436ff2 100644 --- 
a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -56,10 +56,9 @@ import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, - SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend} + SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.scheduler.local.LocalScheduler -import org.apache.spark.scheduler.StageInfo +import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, @@ -149,8 +148,6 @@ class SparkContext( private[spark] var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters @@ -162,23 +159,26 @@ class SparkContext( master match { case "local" => - new LocalScheduler(1, 0, this) + val scheduler = new TaskScheduler(this) + val backend = new LocalBackend(scheduler, 1) + scheduler.initialize(backend) + scheduler case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0, this) - - case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt, this) + val scheduler = new TaskScheduler(this) + val backend = new LocalBackend(scheduler, threads.toInt) + scheduler.initialize(backend) + scheduler case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) + val scheduler = new TaskScheduler(this) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) scheduler.initialize(backend) scheduler case SIMR_REGEX(simrUrl) => - val scheduler = new ClusterScheduler(this) + val scheduler = new TaskScheduler(this) val backend = new SimrSchedulerBackend(scheduler, this, simrUrl) scheduler.initialize(backend) scheduler @@ -192,7 +192,7 @@ class SparkContext( memoryPerSlaveInt, SparkContext.executorMemoryRequested)) } - val scheduler = new ClusterScheduler(this) + val scheduler = new TaskScheduler(this) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val masterUrls = localCluster.start() @@ -207,7 +207,7 @@ class SparkContext( val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[ClusterScheduler] + cons.newInstance(this).asInstanceOf[TaskScheduler] } catch { // TODO: Enumerate the exact reasons why it can fail // But irrespective of it, it means we cannot proceed ! 
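For reference, every branch of the master-URL match above converges on the same shape: build one TaskScheduler, pair it with the appropriate SchedulerBackend, and call initialize() before returning. A minimal sketch of that shared pattern (the helper name makeScheduler, the "example-app" string, and the error branch are illustrative only, not part of the patch; the imports are those already at the top of SparkContext.scala):

    def makeScheduler(sc: SparkContext, master: String): TaskScheduler = {
      val scheduler = new TaskScheduler(sc)
      val backend = master match {
        case "local" =>
          new LocalBackend(scheduler, 1)
        case url if url.startsWith("spark://") =>
          new SparkDeploySchedulerBackend(scheduler, sc, Array(url), "example-app")
        case other =>
          throw new SparkException("Unsupported master URL in this sketch: " + other)
      }
      scheduler.initialize(backend)  // wires in the backend and builds the FIFO/FAIR pools
      scheduler
    }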
@@ -221,7 +221,7 @@ class SparkContext( case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) + val scheduler = new TaskScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName) @@ -593,9 +593,7 @@ class SparkContext( } addedFiles(key) = System.currentTimeMillis - // Fetch the file locally in case a job is executed locally. - // Jobs that run through LocalScheduler will already fetch the required dependencies, - // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. + // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala new file mode 100644 index 0000000000..2bc43a9186 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.executor.ExecutorExitCode + +/** + * Represents an explanation for a executor or whole slave failing or exiting. + */ +private[spark] +class ExecutorLossReason(val message: String) { + override def toString: String = message +} + +private[spark] +case class ExecutorExited(val exitCode: Int) + extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { +} + +private[spark] +case class SlaveLost(_message: String = "Slave lost") + extends ExecutorLossReason(_message) { +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala new file mode 100644 index 0000000000..1f0839a0e1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.SparkContext + +/** + * A backend interface for scheduling systems that allows plugging in different ones under + * TaskScheduler. We assume a Mesos-like model where the application gets resource offers as + * machines become available and can launch tasks on them. + */ +private[spark] trait SchedulerBackend { + def start(): Unit + def stop(): Unit + def reviveOffers(): Unit + def defaultParallelism(): Int + + def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException + + // Memory used by each executor (in megabytes) + protected val executorMemory: Int = SparkContext.executorMemoryRequested +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala new file mode 100644 index 0000000000..5408fa7353 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.nio.ByteBuffer +import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit} + +import org.apache.spark._ +import org.apache.spark.TaskState.TaskState +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.Utils + +/** + * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. 
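To make the SchedulerBackend contract above concrete, here is a hypothetical minimal in-process backend; the class name InProcessBackend and the runTask stub are invented for illustration (the real single-process implementation is the LocalBackend added later in this patch). Like every backend, it turns its resources into WorkerOffers and hands them to TaskScheduler.resourceOffers, then launches whatever TaskDescriptions come back:

    private[spark] class InProcessBackend(scheduler: TaskScheduler, cores: Int)
      extends SchedulerBackend {

      override def start() {}
      override def stop() {}

      // Offer every core on a single local "executor" and launch the tasks the scheduler picks.
      override def reviveOffers() {
        val offers = Seq(new WorkerOffer("executor-0", "localhost", cores))
        for (perOffer <- scheduler.resourceOffers(offers); task <- perOffer) {
          runTask(task)  // hypothetical: deserialize and run task.serializedTask
        }
      }

      override def defaultParallelism(): Int = cores

      private def runTask(task: TaskDescription) { /* omitted */ }
    }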
+ */ +private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskScheduler) + extends Logging { + private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt + private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( + THREADS, "Result resolver thread") + + protected val serializer = new ThreadLocal[SerializerInstance] { + override def initialValue(): SerializerInstance = { + return sparkEnv.closureSerializer.newInstance() + } + } + + def enqueueSuccessfulTask( + taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { + getTaskResultExecutor.execute(new Runnable { + override def run() { + try { + val result = serializer.get().deserialize[TaskResult[_]](serializedData) match { + case directResult: DirectTaskResult[_] => directResult + case IndirectTaskResult(blockId) => + logDebug("Fetching indirect task result for TID %s".format(tid)) + scheduler.handleTaskGettingResult(taskSetManager, tid) + val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) + if (!serializedTaskResult.isDefined) { + /* We won't be able to get the task result if the machine that ran the task failed + * between when the task ended and when we tried to fetch the result, or if the + * block manager had to flush the result. */ + scheduler.handleFailedTask( + taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost)) + return + } + val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( + serializedTaskResult.get) + sparkEnv.blockManager.master.removeBlock(blockId) + deserializedResult + } + result.metrics.resultSize = serializedData.limit() + scheduler.handleSuccessfulTask(taskSetManager, tid, result) + } catch { + case cnf: ClassNotFoundException => + val loader = Thread.currentThread.getContextClassLoader + taskSetManager.abort("ClassNotFound with classloader: " + loader) + case ex => + taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex)) + } + } + }) + } + + def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, + serializedData: ByteBuffer) { + var reason: Option[TaskEndReason] = None + getTaskResultExecutor.execute(new Runnable { + override def run() { + try { + if (serializedData != null && serializedData.limit() > 0) { + reason = Some(serializer.get().deserialize[TaskEndReason]( + serializedData, getClass.getClassLoader)) + } + } catch { + case cnd: ClassNotFoundException => + // Log an error but keep going here -- the task failed, so not catastropic if we can't + // deserialize the reason. 
+ val loader = Thread.currentThread.getContextClassLoader + logError( + "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) + case ex => {} + } + scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) + } + }) + } + + def stop() { + getTaskResultExecutor.shutdownNow() + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 10e0478108..3f694dd25d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -17,39 +17,477 @@ package org.apache.spark.scheduler +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicLong +import java.util.{TimerTask, Timer} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import org.apache.spark._ +import org.apache.spark.TaskState.TaskState +import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** - * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler. - * Each TaskScheduler schedulers task for a single SparkContext. - * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, - * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler. + * Schedules tasks for a single SparkContext. Receives a set of tasks from the DAGScheduler for + * each stage, and is responsible for sending tasks to executors, running them, retrying if there + * are failures, and mitigating stragglers. Returns events to the DAGScheduler. + * + * Clients should first call initialize() and start(), then submit task sets through the + * runTasks method. + * + * This class can work with multiple types of clusters by acting through a SchedulerBackend. + * It can also work with a local setup by using a LocalBackend and setting isLocal to true. + * It handles common logic, like determining a scheduling order across jobs, waking up to launch + * speculative tasks, etc. + * + * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple + * threads, so it needs locks in public API methods to maintain its state. In addition, some + * SchedulerBackends sycnchronize on themselves when they want to send events here, and then + * acquire a lock on us, so we need to make sure that we don't try to lock the backend while + * we are holding a lock on ourselves. */ -private[spark] trait TaskScheduler { +private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = false) extends Logging { + // How often to check for speculative tasks + val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong + + // Threshold above which we warn user initial TaskSet may be starved + val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong + + // TaskSetManagers are not thread safe, so any access to one should be synchronized + // on this class. 
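Further down in this file, the scheduler instantiates this class as a private[spark] var named taskResultGetter, a var precisely so that tests can substitute it. A hypothetical stub along those lines (the class name RecordingResultGetter is made up, taskScheduler and sc stand for an existing TaskScheduler and SparkContext, and the code is assumed to live under the org.apache.spark package so the private[spark] members are visible):

    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    class RecordingResultGetter(env: SparkEnv, scheduler: TaskScheduler)
      extends TaskResultGetter(env, scheduler) {
      val successes = new ArrayBuffer[Long]
      // Record the finished task instead of deserializing or fetching anything.
      override def enqueueSuccessfulTask(
          taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
        successes += tid
      }
    }

    taskScheduler.taskResultGetter = new RecordingResultGetter(sc.env, taskScheduler)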
+ val activeTaskSets = new HashMap[String, TaskSetManager] + + val taskIdToTaskSetId = new HashMap[Long, String] + val taskIdToExecutorId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + @volatile private var hasReceivedTask = false + @volatile private var hasLaunchedTask = false + private val starvationTimer = new Timer(true) + + // Incrementing task IDs + val nextTaskId = new AtomicLong(0) + + // Which executor IDs we have executors on + val activeExecutorIds = new HashSet[String] + + // The set of executors we have on each host; this is used to compute hostsAlive, which + // in turn is used to decide when we can attain data locality on a given host + private val executorsByHost = new HashMap[String, HashSet[String]] + + private val executorIdToHost = new HashMap[String, String] + + // Listener object to pass upcalls into + var dagScheduler: DAGScheduler = null + + var backend: SchedulerBackend = null + + val mapOutputTracker = SparkEnv.get.mapOutputTracker + + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + // default scheduler is FIFO + val schedulingMode: SchedulingMode = SchedulingMode.withName( + System.getProperty("spark.scheduler.mode", "FIFO")) + + // This is a var so that we can reset it for testing purposes. + private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) + + def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler + } + + def initialize(context: SchedulerBackend) { + backend = context + // temporarily set rootPool name to empty + rootPool = new Pool("", schedulingMode, 0, 0) + schedulableBuilder = { + schedulingMode match { + case SchedulingMode.FIFO => + new FIFOSchedulableBuilder(rootPool) + case SchedulingMode.FAIR => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() + } + + def newTaskId(): Long = nextTaskId.getAndIncrement() + + def start() { + backend.start() + + if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) { + new Thread("TaskScheduler speculation check") { + setDaemon(true) + + override def run() { + logInfo("Starting speculative execution thread") + while (true) { + try { + Thread.sleep(SPECULATION_INTERVAL) + } catch { + case e: InterruptedException => {} + } + checkSpeculatableTasks() + } + } + }.start() + } + } + + def submitTasks(taskSet: TaskSet) { + val tasks = taskSet.tasks + logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") + this.synchronized { + val manager = new TaskSetManager(this, taskSet) + activeTaskSets(taskSet.id) = manager + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + + if (!isLocal && !hasReceivedTask) { + starvationTimer.scheduleAtFixedRate(new TimerTask() { + override def run() { + if (!hasLaunchedTask) { + logWarning("Initial job has not accepted any resources; " + + "check your cluster UI to ensure that workers are registered " + + "and have sufficient memory") + } else { + this.cancel() + } + } + }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + } + hasReceivedTask = true + } + backend.reviveOffers() + } - def rootPool: Pool + def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. 
+ // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } + tsm.error("Stage %d was cancelled".format(stageId)) + } + } - def schedulingMode: SchedulingMode + def taskSetFinished(manager: TaskSetManager): Unit = synchronized { + // Check to see if the given task set has been removed. This is possible in the case of + // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has + // more than one running tasks). + if (activeTaskSets.contains(manager.taskSet.id)) { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds.remove(manager.taskSet.id) + } + } - def start(): Unit + /** + * Called by cluster manager to offer resources on slaves. We respond by asking our active task + * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so + * that tasks are balanced across the cluster. + */ + def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { + SparkEnv.set(sc.env) - // Invoked after system has successfully initialized (typically in spark context). - // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. + // Mark each slave as alive and remember its hostname + for (o <- offers) { + executorIdToHost(o.executorId) = o.host + if (!executorsByHost.contains(o.host)) { + executorsByHost(o.host) = new HashSet[String]() + executorGained(o.executorId, o.host) + } + } + + // Build a list of tasks to assign to each worker + val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = offers.map(o => o.cores).toArray + val sortedTaskSets = rootPool.getSortedTaskSetQueue() + for (taskSet <- sortedTaskSets) { + logDebug("parentName: %s, name: %s, runningTasks: %s".format( + taskSet.parent.name, taskSet.name, taskSet.runningTasks)) + } + + // Take each TaskSet in our scheduling order, and then offer it each node in increasing order + // of locality levels so that it gets a chance to launch local tasks on all of them. 
+ var launchedTask = false + for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { + do { + launchedTask = false + for (i <- 0 until offers.size) { + val execId = offers(i).executorId + val host = offers(i).host + for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskSetTaskIds(taskSet.taskSet.id) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + launchedTask = true + } + } + } while (launchedTask) + } + + if (tasks.size > 0) { + hasLaunchedTask = true + } + return tasks + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + var failedExecutor: Option[String] = None + var taskFailed = false + synchronized { + try { + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) + } + } + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => + if (TaskState.isFinished(state)) { + taskIdToTaskSetId.remove(tid) + if (taskSetTaskIds.contains(taskSetId)) { + taskSetTaskIds(taskSetId) -= tid + } + taskIdToExecutorId.remove(tid) + } + if (state == TaskState.FAILED) { + taskFailed = true + } + activeTaskSets.get(taskSetId).foreach { taskSet => + if (state == TaskState.FINISHED) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + } + } + case None => + logInfo("Ignoring update from TID " + tid + " because its task set is gone") + } + } catch { + case e: Exception => logError("Exception in statusUpdate", e) + } + } + // Update the DAGScheduler without holding a lock on this, since that can deadlock + if (failedExecutor != None) { + dagScheduler.executorLost(failedExecutor.get) + backend.reviveOffers() + } + if (taskFailed) { + // Also revive offers if a task had failed for some reason other than host lost + backend.reviveOffers() + } + } + + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { + taskSetManager.handleTaskGettingResult(tid) + } + + def handleSuccessfulTask( + taskSetManager: TaskSetManager, + tid: Long, + taskResult: DirectTaskResult[_]) = synchronized { + taskSetManager.handleSuccessfulTask(tid, taskResult) + } + + def handleFailedTask( + taskSetManager: TaskSetManager, + tid: Long, + taskState: TaskState, + reason: Option[TaskEndReason]) = synchronized { + taskSetManager.handleFailedTask(tid, taskState, reason) + if (taskState == TaskState.FINISHED) { + // The task finished successfully but the result was lost, so we should revive offers. + backend.reviveOffers() + } + } + + def error(message: String) { + synchronized { + if (activeTaskSets.size > 0) { + // Have each task set throw a SparkException with the error + for ((taskSetId, manager) <- activeTaskSets) { + try { + manager.error(message) + } catch { + case e: Exception => logError("Exception in error callback", e) + } + } + } else { + // No task sets are active but we still got an error. Just exit since this + // must mean the error is during registration. 
+ // It might be good to do something smarter here in the future. + logError("Exiting due to error from task scheduler: " + message) + System.exit(1) + } + } + } + + def stop() { + if (backend != null) { + backend.stop() + } + if (taskResultGetter != null) { + taskResultGetter.stop() + } + + // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. + // TODO: Do something better ! + Thread.sleep(5000L) + } + + def defaultParallelism() = backend.defaultParallelism() + + // Check for speculatable tasks in all our active jobs. + def checkSpeculatableTasks() { + var shouldRevive = false + synchronized { + shouldRevive = rootPool.checkSpeculatableTasks() + } + if (shouldRevive) { + backend.reviveOffers() + } + } + + // Check for pending tasks in all our active jobs. + def hasPendingTasks: Boolean = { + synchronized { + rootPool.hasPendingTasks() + } + } + + def executorLost(executorId: String, reason: ExecutorLossReason) { + var failedExecutor: Option[String] = None + + synchronized { + if (activeExecutorIds.contains(executorId)) { + val hostPort = executorIdToHost(executorId) + logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + removeExecutor(executorId) + failedExecutor = Some(executorId) + } else { + // We may get multiple executorLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor " + executorId + " (already removed): " + reason) + } + } + // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock + if (failedExecutor != None) { + dagScheduler.executorLost(failedExecutor.get) + backend.reviveOffers() + } + } + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + rootPool.executorLost(executorId, host) + } + + def executorGained(execId: String, host: String) { + dagScheduler.executorGained(execId, host) + } + + def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { + executorsByHost.get(host).map(_.toSet) + } + + def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { + executorsByHost.contains(host) + } + + def isExecutorAlive(execId: String): Boolean = synchronized { + activeExecutorIds.contains(execId) + } + + // By default, rack is unknown + def getRackForHost(value: String): Option[String] = None + + /** + * Invoked after the system has successfully been initialized. YARN uses this to bootstrap + * allocation of resources based on preferred locations, wait for slave registrations, etc. + */ def postStartHook() { } +} + - // Disconnect from the cluster. - def stop(): Unit +object TaskScheduler { + /** + * Used to balance containers across hosts. + * + * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of + * resource offers representing the order in which the offers should be used. 
The resource + * offers are ordered such that we'll allocate one container on each host before allocating a + * second container on any host, and so on, in order to reduce the damage if a host fails. + * + * For example, given , , , returns + * [o1, o5, o4, 02, o6, o3] + */ + def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { + val _keyList = new ArrayBuffer[K](map.size) + _keyList ++= map.keys - // Submit a sequence of tasks to run. - def submitTasks(taskSet: TaskSet): Unit + // order keyList based on population of value in map + val keyList = _keyList.sortWith( + (left, right) => map(left).size > map(right).size + ) - // Cancel a stage. - def cancelTasks(stageId: Int) + val retval = new ArrayBuffer[T](keyList.size * 2) + var index = 0 + var found = true - // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. - def setDAGScheduler(dagScheduler: DAGScheduler): Unit + while (found) { + found = false + for (key <- keyList) { + val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) + assert(containerList != null) + // Get the index'th entry for this host - if present + if (index < containerList.size){ + retval += containerList.apply(index) + found = true + } + } + index += 1 + } - // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. - def defaultParallelism(): Int + retval.toList + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 90f6bcefac..13271b10f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -17,32 +17,690 @@ package org.apache.spark.scheduler -import java.nio.ByteBuffer +import java.util.Arrays +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.math.max +import scala.math.min + +import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, + Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState +import org.apache.spark.scheduler._ +import org.apache.spark.util.{SystemClock, Clock} + /** - * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of - * each task and is responsible for retries on failure and locality. The main interfaces to it - * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and - * statusUpdate, which tells it that one of its tasks changed state (e.g. finished). + * Schedules the tasks within a single TaskSet in the TaskScheduler. This class keeps track of + * each task, retries tasks if they fail (up to a limited number of times), and + * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces + * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, + * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished). * - * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler - * (e.g. its event handlers). It should not be called from other threads. + * THREADING: This class is designed to only be called from code with a lock on the + * TaskScheduler (e.g. its event handlers). It should not be called from other threads. 
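Stepping back to the TaskScheduler companion object above: the example in the prioritizeContainers comment lost its host-to-offer lists somewhere along the way (the "given , , ," fragment), and "02" should read "o2". The intended scenario is presumably three hosts holding three, one, and two offers respectively, which the following sketch reproduces:

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    val offersByHost = HashMap(
      "h1" -> ArrayBuffer("o1", "o2", "o3"),
      "h2" -> ArrayBuffer("o4"),
      "h3" -> ArrayBuffer("o5", "o6"))

    // One offer is taken from each host per round, hosts with the most offers first,
    // so losing any single host costs at most one container per round.
    TaskScheduler.prioritizeContainers(offersByHost)
    // => List(o1, o5, o4, o2, o6, o3)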
*/ -private[spark] trait TaskSetManager extends Schedulable { - def schedulableQueue = null - - def schedulingMode = SchedulingMode.NONE - - def taskSet: TaskSet +private[spark] class TaskSetManager( + sched: TaskScheduler, + val taskSet: TaskSet, + clock: Clock = SystemClock) + extends Schedulable with Logging +{ + // CPUs to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt + + // Maximum times a task is allowed to fail before failing the job + val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt + + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble + + // Serializer for closures and tasks. + val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() + + val tasks = taskSet.tasks + val numTasks = tasks.length + val copiesRunning = new Array[Int](numTasks) + val successful = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + var tasksSuccessful = 0 + + var weight = 1 + var minShare = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent: Pool = null + + var runningTasks = 0 + private val runningTasksSet = new HashSet[Long] + + // Set of pending tasks for each executor. These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] + + // Set of pending tasks for each host. Similar to pendingTasksForExecutor, + // but at host level. + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // Set of pending tasks for each rack -- similar to the above. + private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] + + // Set containing pending tasks with no locality preferences. + val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + + // Set containing all pending tasks (also used as a stack, as above). + val allPendingTasks = new ArrayBuffer[Int] + + // Tasks that can be speculated. Since these will be a small fraction of total + // tasks, we'll just hold them in a HashSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[Long, TaskInfo] + + // Did the TaskSet fail? + var failed = false + var causeOfFailure = "" + + // How frequently to reprint duplicate exceptions in full, in milliseconds + val EXCEPTION_PRINT_INTERVAL = + System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong + + // Map of recent exceptions (identified by string representation and top stack frame) to + // duplicate count (how many times the same exception has appeared) and time the full exception + // was printed. This should ideally be an LRU map that can drop old exceptions automatically. 
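The retry and speculation behavior of a TaskSetManager is tuned through the system properties read above; spark.speculation itself is checked in TaskScheduler.start, and the locality waits are read in getLocalityWait near the end of this file. A hedged example of overriding them before the SparkContext is created (the values are arbitrary):

    System.setProperty("spark.task.maxFailures", "8")        // allow more retries per task
    System.setProperty("spark.speculation", "true")          // start the speculation thread
    System.setProperty("spark.speculation.quantile", "0.9")  // wait for 90% of tasks before speculating
    System.setProperty("spark.speculation.multiplier", "2")  // speculate at 2x the median duration
    System.setProperty("spark.locality.wait", "5000")        // delay-scheduling wait per level, in ms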
+ val recentExceptions = HashMap[String, (Int, Long)]() + + // Figure out the current map output tracker epoch and set it on all tasks + val epoch = sched.mapOutputTracker.getEpoch + logDebug("Epoch for " + taskSet + ": " + epoch) + for (t <- tasks) { + t.epoch = epoch + } + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. + for (i <- (0 until numTasks).reverse) { + addPendingTask(i) + } + + // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling + val myLocalityLevels = computeValidLocalityLevels() + val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + + // Delay scheduling variables: we keep track of our current locality level and the time we + // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. + // We then move down if we manage to launch a "more local" task. + var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + var lastLaunchTime = clock.getTime() // Time we last launched a task at this level + + override def schedulableQueue = null + + override def schedulingMode = SchedulingMode.NONE + + /** + * Add a task to all the pending-task lists that it should be on. If readding is set, we are + * re-adding the task so only include it in each list if it's not already there. + */ + private def addPendingTask(index: Int, readding: Boolean = false) { + // Utility method that adds `index` to a list only if readding=false or it's not already there + def addTo(list: ArrayBuffer[Int]) { + if (!readding || !list.contains(index)) { + list += index + } + } + + var hadAliveLocations = false + for (loc <- tasks(index).preferredLocations) { + for (execId <- loc.executorId) { + if (sched.isExecutorAlive(execId)) { + addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) + hadAliveLocations = true + } + } + if (sched.hasExecutorsAliveOnHost(loc.host)) { + addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) + for (rack <- sched.getRackForHost(loc.host)) { + addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) + } + hadAliveLocations = true + } + } + + if (!hadAliveLocations) { + // Even though the task might've had preferred locations, all of those hosts or executors + // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. + addTo(pendingTasksWithNoPrefs) + } + + if (!readding) { + allPendingTasks += index // No point scanning this whole list to find the old task there + } + } + + /** + * Return the pending tasks list for a given executor ID, or an empty list if + * there is no map entry for that host + */ + private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = { + pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer()) + } + + /** + * Return the pending tasks list for a given host, or an empty list if + * there is no map entry for that host + */ + private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + pendingTasksForHost.getOrElse(host, ArrayBuffer()) + } + + /** + * Return the pending rack-local task list for a given rack, or an empty list if + * there is no map entry for that rack + */ + private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = { + pendingTasksForRack.getOrElse(rack, ArrayBuffer()) + } + + /** + * Dequeue a pending task from the given list and return its index. + * Return None if the list is empty. 
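The delay-scheduling variables above are easiest to follow as a small standalone sketch of the escalation rule that getAllowedLocalityLevel (further down) applies. The wait values here are examples only, standing in for myLocalityLevels = [PROCESS_LOCAL, NODE_LOCAL, ANY]:

    val localityWaits = Array(3000L, 3000L, 0L)  // ms to wait at each level
    var currentLocalityIndex = 0
    var lastLaunchTime = 0L

    def allowedIndex(curTime: Long): Int = {
      while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
             currentLocalityIndex < localityWaits.length - 1) {
        lastLaunchTime += localityWaits(currentLocalityIndex)
        currentLocalityIndex += 1
      }
      currentLocalityIndex
    }

    allowedIndex(2000)  // 0: still within the 3 s PROCESS_LOCAL wait
    allowedIndex(3500)  // 1: the wait expired, so NODE_LOCAL tasks may now launch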
+ * This method also cleans up any tasks in the list that have already + * been launched, since we want that to happen lazily. + */ + private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + while (!list.isEmpty) { + val index = list.last + list.trimEnd(1) + if (copiesRunning(index) == 0 && !successful(index)) { + return Some(index) + } + } + return None + } + + /** Check whether a task is currently running an attempt on a given host */ + private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { + !taskAttempts(taskIndex).exists(_.host == host) + } + + /** + * Return a speculative task for a given executor if any are available. The task should not have + * an attempt running on this host, in case the host is slow. In addition, the task should meet + * the given locality constraint. + */ + private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + : Option[(Int, TaskLocality.Value)] = + { + speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set + + if (!speculatableTasks.isEmpty) { + // Check for process-local or preference-less tasks; note that tasks can be process-local + // on multiple nodes when we replicate cached blocks, as in Spark Streaming + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val prefs = tasks(index).preferredLocations + val executors = prefs.flatMap(_.executorId) + if (prefs.size == 0 || executors.contains(execId)) { + speculatableTasks -= index + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + } + + // Check for node-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val locations = tasks(index).preferredLocations.map(_.host) + if (locations.contains(host)) { + speculatableTasks -= index + return Some((index, TaskLocality.NODE_LOCAL)) + } + } + } + + // Check for rack-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + for (rack <- sched.getRackForHost(host)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) + if (racks.contains(rack)) { + speculatableTasks -= index + return Some((index, TaskLocality.RACK_LOCAL)) + } + } + } + } + // Check for non-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + speculatableTasks -= index + return Some((index, TaskLocality.ANY)) + } + } + } + + return None + } + + /** + * Dequeue a pending task for a given node and return its index and locality level. + * Only search for tasks matching the given locality constraint. + */ + private def findTask(execId: String, host: String, locality: TaskLocality.Value) + : Option[(Int, TaskLocality.Value)] = + { + for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + + if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + for (index <- findTaskFromList(getPendingTasksForHost(host))) { + return Some((index, TaskLocality.NODE_LOCAL)) + } + } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + for { + rack <- sched.getRackForHost(host) + index <- findTaskFromList(getPendingTasksForRack(rack)) + } { + return Some((index, TaskLocality.RACK_LOCAL)) + } + } + + // Look for no-pref tasks after rack-local tasks since they can run anywhere. 
+ for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + for (index <- findTaskFromList(allPendingTasks)) { + return Some((index, TaskLocality.ANY)) + } + } + + // Finally, if all else has failed, find a speculative task + return findSpeculativeTask(execId, host, locality) + } + + /** + * Respond to an offer of a single executor from the scheduler by finding a task + */ def resourceOffer( execId: String, host: String, availableCpus: Int, maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] + : Option[TaskDescription] = + { + if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) { + val curTime = clock.getTime() + + var allowedLocality = getAllowedLocalityLevel(curTime) + if (allowedLocality > maxLocality) { + allowedLocality = maxLocality // We're not allowed to search for farther-away tasks + } + + findTask(execId, host, allowedLocality) match { + case Some((index, taskLocality)) => { + // Found a task; do some bookkeeping and return a task description + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred launch + logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( + taskSet.id, index, taskId, execId, host, taskLocality)) + // Do various bookkeeping + copiesRunning(index) += 1 + val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + // Update our locality level for delay scheduling + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + // Serialize and return the task + val startTime = clock.getTime() + // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here + // we assume the task can be serialized without exceptions. + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val timeTaken = clock.getTime() - startTime + addRunningTask(taskId) + logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) + val taskName = "task %s:%d".format(taskSet.id, index) + if (taskAttempts(index).size == 1) + taskStarted(task,info) + return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) + } + case _ => + } + } + return None + } + + /** + * Get the level we can launch tasks according to delay scheduling, based on current wait time. + */ + private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { + while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && + currentLocalityIndex < myLocalityLevels.length - 1) + { + // Jump to the next locality level, and remove our waiting time for the current one since + // we don't want to count it again on the next one + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + } + myLocalityLevels(currentLocalityIndex) + } + + /** + * Find the index in myLocalityLevels for a given locality. This is also designed to work with + * localities that are not in myLocalityLevels (in case we somehow get those) by returning the + * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. 
+ */ + def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { + var index = 0 + while (locality > myLocalityLevels(index)) { + index += 1 + } + index + } + + private def taskStarted(task: Task[_], info: TaskInfo) { + sched.dagScheduler.taskStarted(task, info) + } + + def handleTaskGettingResult(tid: Long) = { + val info = taskInfos(tid) + info.markGettingResult() + sched.dagScheduler.taskGettingResult(tasks(info.index), info) + } + + /** + * Marks the task as successful and notifies the DAGScheduler that a task has ended. + */ + def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { + val info = taskInfos(tid) + val index = info.index + info.markSuccessful() + removeRunningTask(tid) + if (!successful(index)) { + logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( + tid, info.duration, info.host, tasksSuccessful, numTasks)) + sched.dagScheduler.taskEnded( + tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + + // Mark successful and stop if all the tasks have succeeded. + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignorning task-finished event for TID " + tid + " because task " + + index + " has already completed successfully") + } + } + + /** + * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the + * DAG Scheduler. + */ + def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { + val info = taskInfos(tid) + if (info.failed) { + return + } + removeRunningTask(tid) + val index = info.index + info.markFailed() + if (!successful(index)) { + logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. In that case, this + // task will never succeed on any node, so tell the scheduler about it. 
+ reason.foreach { + case fetchFailed: FetchFailed => + logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) + successful(index) = true + tasksSuccessful += 1 + sched.taskSetFinished(this) + removeAllRunningTasks() + return + + case TaskKilled => + logWarning("Task %d was killed.".format(tid)) + sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) + return + + case ef: ExceptionFailure => + sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + val key = ef.description + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logWarning("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } + + case TaskResultLost => + logWarning("Lost result for TID %s on host %s".format(tid, info.host)) + sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) + + case _ => {} + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + if (state != TaskState.KILLED) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def error(message: String) { + // Save the error message + abort("Error: " + message) + } + + def abort(message: String) { + failed = true + causeOfFailure = message + // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.dagScheduler.taskSetFailed(taskSet, message) + removeAllRunningTasks() + sched.taskSetFinished(this) + } + + /** If the given task ID is not in the set of running tasks, adds it. + * + * Used to keep track of the number of running tasks, for enforcing scheduling policies. + */ + def addRunningTask(tid: Long) { + if (runningTasksSet.add(tid) && parent != null) { + parent.increaseRunningTasks(1) + } + runningTasks = runningTasksSet.size + } + + /** If the given task ID is in the set of running tasks, removes it. 
*/ + def removeRunningTask(tid: Long) { + if (runningTasksSet.remove(tid) && parent != null) { + parent.decreaseRunningTasks(1) + } + runningTasks = runningTasksSet.size + } + + private def removeAllRunningTasks() { + val numRunningTasks = runningTasksSet.size + runningTasksSet.clear() + if (parent != null) { + parent.decreaseRunningTasks(numRunningTasks) + } + runningTasks = 0 + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def addSchedulable(schedulable: Schedulable) {} + + override def removeSchedulable(schedulable: Schedulable) {} + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ + override def executorLost(execId: String, host: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + + // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a + // task that used to have locations on only this host might now go to the no-prefs list. Note + // that it's okay if we add a task to the same queue twice (if it had multiple preferred + // locations), because findTaskFromList will skip already-running tasks. + for (index <- getPendingTasksForExecutor(execId)) { + addPendingTask(index, readding=true) + } + for (index <- getPendingTasksForHost(host)) { + addPendingTask(index, readding=true) + } + + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.executorId == execId) { + val index = taskInfos(tid).index + if (successful(index)) { + successful(index) = false + copiesRunning(index) -= 1 + tasksSuccessful -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) + } + } + } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { + handleFailedTask(tid, TaskState.KILLED, None) + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the TaskScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + override def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. 
+ if (numTasks == 1 || tasksSuccessful == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { + val time = clock.getTime() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. + logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo( + "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.host, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks + } + + override def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksSuccessful < numTasks + } + + private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { + val defaultWait = System.getProperty("spark.locality.wait", "3000") + level match { + case TaskLocality.PROCESS_LOCAL => + System.getProperty("spark.locality.wait.process", defaultWait).toLong + case TaskLocality.NODE_LOCAL => + System.getProperty("spark.locality.wait.node", defaultWait).toLong + case TaskLocality.RACK_LOCAL => + System.getProperty("spark.locality.wait.rack", defaultWait).toLong + case TaskLocality.ANY => + 0L + } + } - def error(message: String) + /** + * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been + * added to queues using addPendingTask. + */ + private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { + import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} + val levels = new ArrayBuffer[TaskLocality.TaskLocality] + if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { + levels += PROCESS_LOCAL + } + if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { + levels += NODE_LOCAL + } + if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { + levels += RACK_LOCAL + } + levels += ANY + logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) + levels.toArray + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala new file mode 100644 index 0000000000..ba6bab3f91 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
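To make the speculation threshold in checkSpeculatableTasks concrete, here is a worked example with invented durations for a 10-task set, using the default SPECULATION_QUANTILE of 0.75 and SPECULATION_MULTIPLIER of 1.5:

    // 7 of 10 tasks have finished, which meets (0.75 * 10).floor = 7.
    val durations = Array(2000L, 2100L, 2300L, 2500L, 2600L, 3000L, 3200L)  // ms, already sorted
    val medianDuration = durations(math.min((0.5 * 7).round.toInt, durations.length - 1))
    // (0.5 * 7).round = 4, so medianDuration = 2600 ms
    val threshold = math.max(1.5 * medianDuration, 100)
    // threshold = 3900 ms: any still-running single copy older than this becomes speculatable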
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +/** + * Represents free resources available on an executor. + */ +private[spark] +class WorkerOffer(val executorId: String, val host: String, val cores: Int) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala deleted file mode 100644 index 85033958ef..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ /dev/null @@ -1,486 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong -import java.util.{TimerTask, Timer} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import org.apache.spark._ -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode - -/** - * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call - * initialize() and start(), then submit task sets through the runTasks method. - * - * This class can work with multiple types of clusters by acting through a SchedulerBackend. - * It handles common logic, like determining a scheduling order across jobs, waking up to launch - * speculative tasks, etc. - * - * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple - * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends sycnchronize on themselves when they want to send events here, and then - * acquire a lock on us, so we need to make sure that we don't try to lock the backend while - * we are holding a lock on ourselves. 
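The new WorkerOffer above is the currency between a SchedulerBackend and the merged scheduler: a backend describes each batch of free cores as an offer, hands the offers to resourceOffers, and gets back one list of TaskDescriptions per offer, in the same order. A rough sketch of that handshake from a backend's point of view; the class name, the executors map and the launch stub are illustrative stand-ins, and only WorkerOffer, TaskScheduler.resourceOffers and TaskDescription.taskId come from this patch.

    package org.apache.spark.scheduler

    // Illustrative only: a backend-side view of how WorkerOffers feed the scheduler.
    private[spark] class OfferSketch(scheduler: TaskScheduler) {
      // Hypothetical view of the backend's slaves: executor id -> (host, free cores).
      private val executors = Map("exec-1" -> ("host-a", 4), "exec-2" -> ("host-b", 2))

      def makeOffers() {
        val offers = executors.map { case (id, (host, cores)) =>
          new WorkerOffer(id, host, cores)
        }.toSeq
        // One inner Seq per offer, parallel to the offers we passed in.
        val tasks: Seq[Seq[TaskDescription]] = scheduler.resourceOffers(offers)
        for ((taskList, offer) <- tasks.zip(offers); task <- taskList) {
          launch(offer.executorId, task)
        }
      }

      private def launch(executorId: String, task: TaskDescription) {
        println("would launch task " + task.taskId + " on executor " + executorId)
      }
    }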
- */ -private[spark] class ClusterScheduler(val sc: SparkContext) - extends TaskScheduler - with Logging -{ - // How often to check for speculative tasks - val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong - - // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong - - // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized - // on this class. - val activeTaskSets = new HashMap[String, ClusterTaskSetManager] - - val taskIdToTaskSetId = new HashMap[Long, String] - val taskIdToExecutorId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - @volatile private var hasReceivedTask = false - @volatile private var hasLaunchedTask = false - private val starvationTimer = new Timer(true) - - // Incrementing task IDs - val nextTaskId = new AtomicLong(0) - - // Which executor IDs we have executors on - val activeExecutorIds = new HashSet[String] - - // The set of executors we have on each host; this is used to compute hostsAlive, which - // in turn is used to decide when we can attain data locality on a given host - private val executorsByHost = new HashMap[String, HashSet[String]] - - private val executorIdToHost = new HashMap[String, String] - - // Listener object to pass upcalls into - var dagScheduler: DAGScheduler = null - - var backend: SchedulerBackend = null - - val mapOutputTracker = SparkEnv.get.mapOutputTracker - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - // default scheduler is FIFO - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.scheduler.mode", "FIFO")) - - // This is a var so that we can reset it for testing purposes. 
- private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) - - override def setDAGScheduler(dagScheduler: DAGScheduler) { - this.dagScheduler = dagScheduler - } - - def initialize(context: SchedulerBackend) { - backend = context - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - } - - def newTaskId(): Long = nextTaskId.getAndIncrement() - - override def start() { - backend.start() - - if (System.getProperty("spark.speculation", "false").toBoolean) { - new Thread("ClusterScheduler speculation check") { - setDaemon(true) - - override def run() { - logInfo("Starting speculative execution thread") - while (true) { - try { - Thread.sleep(SPECULATION_INTERVAL) - } catch { - case e: InterruptedException => {} - } - checkSpeculatableTasks() - } - } - }.start() - } - } - - override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") - this.synchronized { - val manager = new ClusterTaskSetManager(this, taskSet) - activeTaskSets(taskSet.id) = manager - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - - if (!hasReceivedTask) { - starvationTimer.scheduleAtFixedRate(new TimerTask() { - override def run() { - if (!hasLaunchedTask) { - logWarning("Initial job has not accepted any resources; " + - "check your cluster UI to ensure that workers are registered " + - "and have sufficient memory") - } else { - this.cancel() - } - } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) - } - hasReceivedTask = true - } - backend.reviveOffers() - } - - override def cancelTasks(stageId: Int): Unit = synchronized { - logInfo("Cancelling stage " + stageId) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - val taskIds = taskSetTaskIds(tsm.taskSet.id) - if (taskIds.size > 0) { - taskIds.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId) - } - } - tsm.error("Stage %d was cancelled".format(stageId)) - } - } - - def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - // Check to see if the given task set has been removed. This is possible in the case of - // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has - // more than one running tasks). - if (activeTaskSets.contains(manager.taskSet.id)) { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } - } - - /** - * Called by cluster manager to offer resources on slaves. We respond by asking our active task - * sets for tasks in order of priority. 
We fill each node with tasks in a round-robin manner so - * that tasks are balanced across the cluster. - */ - def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - - // Mark each slave as alive and remember its hostname - for (o <- offers) { - executorIdToHost(o.executorId) = o.host - if (!executorsByHost.contains(o.host)) { - executorsByHost(o.host) = new HashSet[String]() - executorGained(o.executorId, o.host) - } - } - - // Build a list of tasks to assign to each worker - val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue() - for (taskSet <- sortedTaskSets) { - logDebug("parentName: %s, name: %s, runningTasks: %s".format( - taskSet.parent.name, taskSet.name, taskSet.runningTasks)) - } - - // Take each TaskSet in our scheduling order, and then offer it each node in increasing order - // of locality levels so that it gets a chance to launch local tasks on all of them. - var launchedTask = false - for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { - do { - launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).host - for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskSetTaskIds(taskSet.taskSet.id) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - launchedTask = true - } - } - } while (launchedTask) - } - - if (tasks.size > 0) { - hasLaunchedTask = true - } - return tasks - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - var failedExecutor: Option[String] = None - var taskFailed = false - synchronized { - try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) - failedExecutor = Some(execId) - } - } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => - if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } - taskIdToExecutorId.remove(tid) - } - if (state == TaskState.FAILED) { - taskFailed = true - } - activeTaskSets.get(taskSetId).foreach { taskSet => - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) - } - } - case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") - } - } catch { - case e: Exception => logError("Exception in statusUpdate", e) - } - } - // Update the DAGScheduler without holding a lock on this, since that can deadlock - if (failedExecutor != None) { - dagScheduler.executorLost(failedExecutor.get) - backend.reviveOffers() - } - if (taskFailed) { - // Also revive offers if a task had failed for some reason other than host lost - backend.reviveOffers() - } - } - - def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) { - 
taskSetManager.handleTaskGettingResult(tid) - } - - def handleSuccessfulTask( - taskSetManager: ClusterTaskSetManager, - tid: Long, - taskResult: DirectTaskResult[_]) = synchronized { - taskSetManager.handleSuccessfulTask(tid, taskResult) - } - - def handleFailedTask( - taskSetManager: ClusterTaskSetManager, - tid: Long, - taskState: TaskState, - reason: Option[TaskEndReason]) = synchronized { - taskSetManager.handleFailedTask(tid, taskState, reason) - if (taskState == TaskState.FINISHED) { - // The task finished successfully but the result was lost, so we should revive offers. - backend.reviveOffers() - } - } - - def error(message: String) { - synchronized { - if (activeTaskSets.size > 0) { - // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { - try { - manager.error(message) - } catch { - case e: Exception => logError("Exception in error callback", e) - } - } - } else { - // No task sets are active but we still got an error. Just exit since this - // must mean the error is during registration. - // It might be good to do something smarter here in the future. - logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) - } - } - } - - override def stop() { - if (backend != null) { - backend.stop() - } - if (taskResultGetter != null) { - taskResultGetter.stop() - } - - // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. - // TODO: Do something better ! - Thread.sleep(5000L) - } - - override def defaultParallelism() = backend.defaultParallelism() - - - // Check for speculatable tasks in all our active jobs. - def checkSpeculatableTasks() { - var shouldRevive = false - synchronized { - shouldRevive = rootPool.checkSpeculatableTasks() - } - if (shouldRevive) { - backend.reviveOffers() - } - } - - // Check for pending tasks in all our active jobs. - def hasPendingTasks: Boolean = { - synchronized { - rootPool.hasPendingTasks() - } - } - - def executorLost(executorId: String, reason: ExecutorLossReason) { - var failedExecutor: Option[String] = None - - synchronized { - if (activeExecutorIds.contains(executorId)) { - val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) - failedExecutor = Some(executorId) - } else { - // We may get multiple executorLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. 
- logError("Lost an executor " + executorId + " (already removed): " + reason) - } - } - // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock - if (failedExecutor != None) { - dagScheduler.executorLost(failedExecutor.get) - backend.reviveOffers() - } - } - - /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { - activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) - execs -= executorId - if (execs.isEmpty) { - executorsByHost -= host - } - executorIdToHost -= executorId - rootPool.executorLost(executorId, host) - } - - def executorGained(execId: String, host: String) { - dagScheduler.executorGained(execId, host) - } - - def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { - executorsByHost.get(host).map(_.toSet) - } - - def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { - executorsByHost.contains(host) - } - - def isExecutorAlive(execId: String): Boolean = synchronized { - activeExecutorIds.contains(execId) - } - - // By default, rack is unknown - def getRackForHost(value: String): Option[String] = None -} - - -object ClusterScheduler { - /** - * Used to balance containers across hosts. - * - * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource - * offers are ordered such that we'll allocate one container on each host before allocating a - * second container on any host, and so on, in order to reduce the damage if a host fails. - * - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] - */ - def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { - val _keyList = new ArrayBuffer[K](map.size) - _keyList ++= map.keys - - // order keyList based on population of value in map - val keyList = _keyList.sortWith( - (left, right) => map(left).size > map(right).size - ) - - val retval = new ArrayBuffer[T](keyList.size * 2) - var index = 0 - var found = true - - while (found) { - found = false - for (key <- keyList) { - val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) - assert(containerList != null) - // Get the index'th entry for this host - if present - if (index < containerList.size){ - retval += containerList.apply(index) - found = true - } - } - index += 1 - } - - retval.toList - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala deleted file mode 100644 index ee47aaffca..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ /dev/null @@ -1,703 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import java.util.Arrays - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, - Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler._ -import org.apache.spark.util.{SystemClock, Clock} - - -/** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of - * the status of each task, retries tasks if they fail (up to a limited number of times), and - * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces - * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, - * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished). - * - * THREADING: This class is designed to only be called from code with a lock on the - * ClusterScheduler (e.g. its event handlers). It should not be called from other threads. - */ -private[spark] class ClusterTaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet, - clock: Clock = SystemClock) - extends TaskSetManager - with Logging -{ - // CPUs to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val successful = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksSuccessful = 0 - - var weight = 1 - var minShare = 0 - var priority = taskSet.priority - var stageId = taskSet.stageId - var name = "TaskSet_"+taskSet.stageId.toString - var parent: Pool = null - - var runningTasks = 0 - private val runningTasksSet = new HashSet[Long] - - // Set of pending tasks for each executor. These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. 
- private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each host. Similar to pendingTasksForExecutor, - // but at host level. - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each rack -- similar to the above. - private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] - - // Set containing pending tasks with no locality preferences. - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the TaskSet fail? - var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - - // Map of recent exceptions (identified by string representation and top stack frame) to - // duplicate count (how many times the same exception has appeared) and time the full exception - // was printed. This should ideally be an LRU map that can drop old exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker epoch and set it on all tasks - val epoch = sched.mapOutputTracker.getEpoch - logDebug("Epoch for " + taskSet + ": " + epoch) - for (t <- tasks) { - t.epoch = epoch - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. - for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling - val myLocalityLevels = computeValidLocalityLevels() - val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level - - // Delay scheduling variables: we keep track of our current locality level and the time we - // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. - // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level - - /** - * Add a task to all the pending-task lists that it should be on. If readding is set, we are - * re-adding the task so only include it in each list if it's not already there. 
- */ - private def addPendingTask(index: Int, readding: Boolean = false) { - // Utility method that adds `index` to a list only if readding=false or it's not already there - def addTo(list: ArrayBuffer[Int]) { - if (!readding || !list.contains(index)) { - list += index - } - } - - var hadAliveLocations = false - for (loc <- tasks(index).preferredLocations) { - for (execId <- loc.executorId) { - if (sched.isExecutorAlive(execId)) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) - hadAliveLocations = true - } - } - if (sched.hasExecutorsAliveOnHost(loc.host)) { - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) - for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - } - hadAliveLocations = true - } - } - - if (!hadAliveLocations) { - // Even though the task might've had preferred locations, all of those hosts or executors - // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. - addTo(pendingTasksWithNoPrefs) - } - - if (!readding) { - allPendingTasks += index // No point scanning this whole list to find the old task there - } - } - - /** - * Return the pending tasks list for a given executor ID, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = { - pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer()) - } - - /** - * Return the pending tasks list for a given host, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - /** - * Return the pending rack-local task list for a given rack, or an empty list if - * there is no map entry for that rack - */ - private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = { - pendingTasksForRack.getOrElse(rack, ArrayBuffer()) - } - - /** - * Dequeue a pending task from the given list and return its index. - * Return None if the list is empty. - * This method also cleans up any tasks in the list that have already - * been launched, since we want that to happen lazily. - */ - private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !successful(index)) { - return Some(index) - } - } - return None - } - - /** Check whether a task is currently running an attempt on a given host */ - private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { - !taskAttempts(taskIndex).exists(_.host == host) - } - - /** - * Return a speculative task for a given executor if any are available. The task should not have - * an attempt running on this host, in case the host is slow. In addition, the task should meet - * the given locality constraint. 
- */ - private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set - - if (!speculatableTasks.isEmpty) { - // Check for process-local or preference-less tasks; note that tasks can be process-local - // on multiple nodes when we replicate cached blocks, as in Spark Streaming - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val prefs = tasks(index).preferredLocations - val executors = prefs.flatMap(_.executorId) - if (prefs.size == 0 || executors.contains(execId)) { - speculatableTasks -= index - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - } - - // Check for node-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val locations = tasks(index).preferredLocations.map(_.host) - if (locations.contains(host)) { - speculatableTasks -= index - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - } - - // Check for rack-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for (rack <- sched.getRackForHost(host)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) - if (racks.contains(rack)) { - speculatableTasks -= index - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - } - } - - // Check for non-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - speculatableTasks -= index - return Some((index, TaskLocality.ANY)) - } - } - } - - return None - } - - /** - * Dequeue a pending task for a given node and return its index and locality level. - * Only search for tasks matching the given locality constraint. - */ - private def findTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- findTaskFromList(getPendingTasksForHost(host))) { - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for { - rack <- sched.getRackForHost(host) - index <- findTaskFromList(getPendingTasksForRack(rack)) - } { - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - - // Look for no-pref tasks after rack-local tasks since they can run anywhere. 
- for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- findTaskFromList(allPendingTasks)) { - return Some((index, TaskLocality.ANY)) - } - } - - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(execId, host, locality) - } - - /** - * Respond to an offer of a single executor from the scheduler by finding a task - */ - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) { - val curTime = clock.getTime() - - var allowedLocality = getAllowedLocalityLevel(curTime) - if (allowedLocality > maxLocality) { - allowedLocality = maxLocality // We're not allowed to search for farther-away tasks - } - - findTask(execId, host, allowedLocality) match { - case Some((index, taskLocality)) => { - // Found a task; do some bookkeeping and return a task description - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, taskLocality)) - // Do various bookkeeping - copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime - // Serialize and return the task - val startTime = clock.getTime() - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = clock.getTime() - startTime - addRunningTask(taskId) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - if (taskAttempts(index).size == 1) - taskStarted(task,info) - return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) - } - case _ => - } - } - return None - } - - /** - * Get the level we can launch tasks according to delay scheduling, based on current wait time. - */ - private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 - } - myLocalityLevels(currentLocalityIndex) - } - - /** - * Find the index in myLocalityLevels for a given locality. This is also designed to work with - * localities that are not in myLocalityLevels (in case we somehow get those) by returning the - * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. 
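getAllowedLocalityLevel above (presumably carried over into the merged TaskSetManager, since this commit combines the two managers) is the delay-scheduling half of the class: it stays at the most local level that still has pending tasks and only relaxes to the next level once the configured wait (spark.locality.wait.*, 3000 ms by default) elapses without a launch at the current one. A self-contained sketch of just that escalation loop; the level names and timestamps are illustrative.

    object DelaySchedulingSketch {
      def main(args: Array[String]) {
        // Stand-ins for myLocalityLevels and localityWaits; ANY always has a zero wait.
        val myLocalityLevels = Array("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY")
        val localityWaits = Array(3000L, 3000L, 3000L, 0L)

        var currentLocalityIndex = 0
        var lastLaunchTime = 0L  // pretend the last task launch happened at t = 0

        def allowedLevel(curTime: Long): String = {
          while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
              currentLocalityIndex < myLocalityLevels.length - 1) {
            // Credit the wait already spent at this level before moving up one level.
            lastLaunchTime += localityWaits(currentLocalityIndex)
            currentLocalityIndex += 1
          }
          myLocalityLevels(currentLocalityIndex)
        }

        println(allowedLevel(1000))  // PROCESS_LOCAL: only 1 s has passed
        println(allowedLevel(4000))  // NODE_LOCAL: the 3 s process-local wait has expired
        println(allowedLevel(7000))  // RACK_LOCAL: another 3 s with no node-local launch
      }
    }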
- */ - def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { - var index = 0 - while (locality > myLocalityLevels(index)) { - index += 1 - } - index - } - - private def taskStarted(task: Task[_], info: TaskInfo) { - sched.dagScheduler.taskStarted(task, info) - } - - def handleTaskGettingResult(tid: Long) = { - val info = taskInfos(tid) - info.markGettingResult() - sched.dagScheduler.taskGettingResult(tasks(info.index), info) - } - - /** - * Marks the task as successful and notifies the DAGScheduler that a task has ended. - */ - def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { - val info = taskInfos(tid) - val index = info.index - info.markSuccessful() - removeRunningTask(tid) - if (!successful(index)) { - logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( - tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.dagScheduler.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - - // Mark successful and stop if all the tasks have succeeded. - tasksSuccessful += 1 - successful(index) = true - if (tasksSuccessful == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignorning task-finished event for TID " + tid + " because task " + - index + " has already completed successfully") - } - } - - /** - * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the - * DAG Scheduler. - */ - def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { - val info = taskInfos(tid) - if (info.failed) { - return - } - removeRunningTask(tid) - val index = info.index - info.markFailed() - if (!successful(index)) { - logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. 
- reason.foreach { - case fetchFailed: FetchFailed => - logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) - successful(index) = true - tasksSuccessful += 1 - sched.taskSetFinished(this) - removeAllRunningTasks() - return - - case TaskKilled => - logWarning("Task %d was killed.".format(tid)) - sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) - return - - case ef: ExceptionFailure => - sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - val key = ef.description - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logWarning("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) - } - - case TaskResultLost => - logWarning("Lost result for TID %s on host %s".format(tid, info.host)) - sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) - - case _ => {} - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - if (state != TaskState.KILLED) { - numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - override def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def abort(message: String) { - failed = true - causeOfFailure = message - // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) - removeAllRunningTasks() - sched.taskSetFinished(this) - } - - /** If the given task ID is not in the set of running tasks, adds it. - * - * Used to keep track of the number of running tasks, for enforcing scheduling policies. - */ - def addRunningTask(tid: Long) { - if (runningTasksSet.add(tid) && parent != null) { - parent.increaseRunningTasks(1) - } - runningTasks = runningTasksSet.size - } - - /** If the given task ID is in the set of running tasks, removes it. 
*/ - def removeRunningTask(tid: Long) { - if (runningTasksSet.remove(tid) && parent != null) { - parent.decreaseRunningTasks(1) - } - runningTasks = runningTasksSet.size - } - - private def removeAllRunningTasks() { - val numRunningTasks = runningTasksSet.size - runningTasksSet.clear() - if (parent != null) { - parent.decreaseRunningTasks(numRunningTasks) - } - runningTasks = 0 - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def addSchedulable(schedulable: Schedulable) {} - - override def removeSchedulable(schedulable: Schedulable) {} - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a - // task that used to have locations on only this host might now go to the no-prefs list. Note - // that it's okay if we add a task to the same queue twice (if it had multiple preferred - // locations), because findTaskFromList will skip already-running tasks. - for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) - } - for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) - } - - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (successful(index)) { - successful(index) = false - copiesRunning(index) -= 1 - tasksSuccessful -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.KILLED, None) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. 
- if (numTasks == 1 || tasksSuccessful == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { - val time = clock.getTime() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. - logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } - - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksSuccessful < numTasks - } - - private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = System.getProperty("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - System.getProperty("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - System.getProperty("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - System.getProperty("spark.locality.wait.rack", defaultWait).toLong - case TaskLocality.ANY => - 0L - } - } - - /** - * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been - * added to queues using addPendingTask. 
- */ - private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { - import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} - val levels = new ArrayBuffer[TaskLocality.TaskLocality] - if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { - levels += PROCESS_LOCAL - } - if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { - levels += NODE_LOCAL - } - if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { - levels += RACK_LOCAL - } - levels += ANY - logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) - levels.toArray - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 70f3f88401..b8ac498527 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -29,7 +29,8 @@ import akka.util.Duration import akka.util.duration._ import org.apache.spark.{SparkException, Logging, TaskState} -import org.apache.spark.scheduler.TaskDescription +import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskScheduler, + WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.Utils @@ -42,7 +43,7 @@ import org.apache.spark.util.Utils * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskScheduler, actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala deleted file mode 100644 index 5077b2b48b..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import org.apache.spark.executor.ExecutorExitCode - -/** - * Represents an explanation for a executor or whole slave failing or exiting. 
- */ -private[spark] -class ExecutorLossReason(val message: String) { - override def toString: String = message -} - -private[spark] -case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { -} - -private[spark] -case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala deleted file mode 100644 index 5367218faa..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import org.apache.spark.SparkContext - -/** - * A backend interface for cluster scheduling systems that allows plugging in different ones under - * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as - * machines become available and can launch tasks on them. 
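ExecutorLossReason above and the SchedulerBackend trait whose removal continues just below are re-homed under org.apache.spark.scheduler by this patch, as the import changes in the backend files elsewhere in the diff show. For orientation, a rough sketch of the smallest possible backend against that trait, assuming the recreated copy keeps the same members as the one removed here; it runs nothing, but it shows the surface a cluster integration has to provide.

    package org.apache.spark.scheduler

    // Illustrative only: a do-nothing backend implementing the plug-in surface.
    private[spark] class NoopSchedulerBackend extends SchedulerBackend {
      override def start() { }
      override def stop() { }
      override def reviveOffers() { }        // a real backend re-offers its free cores here
      override def defaultParallelism(): Int = 1
      // killTask keeps the trait's default, which throws UnsupportedOperationException,
      // and executorMemory keeps the default taken from SparkContext.
    }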
- */ -private[spark] trait SchedulerBackend { - def start(): Unit - def stop(): Unit - def reviveOffers(): Unit - def defaultParallelism(): Int - - def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException - - // Memory used by each executor (in megabytes) - protected val executorMemory: Int = SparkContext.executorMemoryRequested -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index d78bdbaa7a..a589e7456f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,10 +19,12 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem} + import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.scheduler.TaskScheduler private[spark] class SimrSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskScheduler, sc: SparkContext, driverFilePath: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index cefa970bb9..15c600a1ec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,14 +17,16 @@ package org.apache.spark.scheduler.cluster +import scala.collection.mutable.HashMap + import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.client.{Client, ClientListener} import org.apache.spark.deploy.{Command, ApplicationDescription} -import scala.collection.mutable.HashMap +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskScheduler} import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskScheduler, sc: SparkContext, masters: Array[String], appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala deleted file mode 100644 index 2064d97b49..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.cluster - -import java.nio.ByteBuffer -import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit} - -import org.apache.spark._ -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} -import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.Utils - -/** - * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. - */ -private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) - extends Logging { - private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt - private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( - THREADS, "Result resolver thread") - - protected val serializer = new ThreadLocal[SerializerInstance] { - override def initialValue(): SerializerInstance = { - return sparkEnv.closureSerializer.newInstance() - } - } - - def enqueueSuccessfulTask( - taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) { - getTaskResultExecutor.execute(new Runnable { - override def run() { - try { - val result = serializer.get().deserialize[TaskResult[_]](serializedData) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => - logDebug("Fetching indirect task result for TID %s".format(tid)) - scheduler.handleTaskGettingResult(taskSetManager, tid) - val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) - if (!serializedTaskResult.isDefined) { - /* We won't be able to get the task result if the machine that ran the task failed - * between when the task ended and when we tried to fetch the result, or if the - * block manager had to flush the result. */ - scheduler.handleFailedTask( - taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost)) - return - } - val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( - serializedTaskResult.get) - sparkEnv.blockManager.master.removeBlock(blockId) - deserializedResult - } - result.metrics.resultSize = serializedData.limit() - scheduler.handleSuccessfulTask(taskSetManager, tid, result) - } catch { - case cnf: ClassNotFoundException => - val loader = Thread.currentThread.getContextClassLoader - taskSetManager.abort("ClassNotFound with classloader: " + loader) - case ex => - taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex)) - } - } - }) - } - - def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState, - serializedData: ByteBuffer) { - var reason: Option[TaskEndReason] = None - getTaskResultExecutor.execute(new Runnable { - override def run() { - try { - if (serializedData != null && serializedData.limit() > 0) { - reason = Some(serializer.get().deserialize[TaskEndReason]( - serializedData, getClass.getClassLoader)) - } - } catch { - case cnd: ClassNotFoundException => - // Log an error but keep going here -- the task failed, so not catastropic if we can't - // deserialize the reason. 
- val loader = Thread.currentThread.getContextClassLoader - logError( - "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex => {} - } - scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) - } - }) - } - - def stop() { - getTaskResultExecutor.shutdownNow() - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala deleted file mode 100644 index 938f62883a..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -/** - * Represents free resources available on an executor. - */ -private[spark] -class WorkerOffer(val executorId: String, val host: String, val cores: Int) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 300fe693f1..310da0027e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,7 +30,8 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{SparkException, Logging, SparkContext, TaskState} -import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -43,7 +44,7 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedu * remove this. 
*/ private[spark] class CoarseMesosSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskScheduler, sc: SparkContext, master: String, appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 50cbc2ca92..c0e99df0b6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -30,9 +30,8 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{Logging, SparkException, SparkContext, TaskState} -import org.apache.spark.scheduler.TaskDescription -import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason} -import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, + TaskDescription, TaskScheduler, WorkerOffer} import org.apache.spark.util.Utils /** @@ -41,7 +40,7 @@ import org.apache.spark.util.Utils * from multiple apps can run on different cores) and in time (a core can switch ownership). */ private[spark] class MesosSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskScheduler, sc: SparkContext, master: String, appName: String) @@ -210,7 +209,7 @@ private[spark] class MesosSchedulerBackend( getResource(offer.getResourcesList, "cpus").toInt) } - // Call into the ClusterScheduler + // Call into the TaskScheduler val taskLists = scheduler.resourceOffers(offerableWorkers) // Build a list of Mesos tasks for each slave diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala new file mode 100644 index 0000000000..96c3a03602 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.local + +import java.nio.ByteBuffer + +import akka.actor.{Actor, ActorRef, Props} + +import org.apache.spark.{SparkContext, SparkEnv, TaskState} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, WorkerOffer} + +/** + * LocalBackend sits behind a TaskScheduler and handles launching tasks on a single Executor + * (created by the LocalBackend) running locally. 
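+ *
+ * A rough construction sketch (illustrative only; the actual wiring and start-up of these two
+ * objects happens elsewhere, e.g. in SparkContext, and is not shown in this file):
+ *
+ * {{{
+ *   val scheduler = new TaskScheduler(sc)        // scheduling logic shared with cluster mode
+ *   val backend = new LocalBackend(scheduler, 4) // 4 == totalCores available to local tasks
+ * }}}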
+ * + * THREADING: Because methods can be called both from the Executor and the TaskScheduler, and + * because the Executor class is not thread safe, all methods are synchronized. + */ +private[spark] class LocalBackend(scheduler: TaskScheduler, private val totalCores: Int) + extends SchedulerBackend with ExecutorBackend { + + private var freeCores = totalCores + + private val localExecutorId = "localhost" + private val localExecutorHostname = "localhost" + + val executor = new Executor(localExecutorId, localExecutorHostname, Seq.empty, isLocal = true) + + override def start() { + } + + override def stop() { + } + + override def reviveOffers() = synchronized { + val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + for (task <- scheduler.resourceOffers(offers).flatten) { + freeCores -= 1 + executor.launchTask(this, task.taskId, task.serializedTask) + } + } + + override def defaultParallelism() = totalCores + + override def killTask(taskId: Long, executorId: String) = synchronized { + executor.killTask(taskId) + } + + override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) = synchronized { + scheduler.statusUpdate(taskId, state, serializedData) + if (TaskState.isFinished(state)) { + freeCores += 1 + reviveOffers() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala deleted file mode 100644 index 2699f0b33e..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.local - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - -import akka.actor._ - -import org.apache.spark._ -import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode - - -/** - * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally - * the scheduler also allows each task to fail up to maxFailures times, which is useful for - * testing fault recovery. 
- */ - -private[local] -case class LocalReviveOffers() - -private[local] -case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) - -private[local] -case class KillTask(taskId: Long) - -private[spark] -class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) - extends Actor with Logging { - - val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true) - - def receive = { - case LocalReviveOffers => - launchTask(localScheduler.resourceOffer(freeCores)) - - case LocalStatusUpdate(taskId, state, serializeData) => - if (TaskState.isFinished(state)) { - freeCores += 1 - launchTask(localScheduler.resourceOffer(freeCores)) - } - - case KillTask(taskId) => - executor.killTask(taskId) - } - - private def launchTask(tasks: Seq[TaskDescription]) { - for (task <- tasks) { - freeCores -= 1 - executor.launchTask(localScheduler, task.taskId, task.serializedTask) - } - } -} - -private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) - extends TaskScheduler - with ExecutorBackend - with Logging { - - val env = SparkEnv.get - val attemptId = new AtomicInteger - var dagScheduler: DAGScheduler = null - - // Application dependencies (added through SparkContext) that we've fetched so far on this node. - // Each map holds the master's timestamp for the version of that file or JAR we got. - val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.scheduler.mode", "FIFO")) - val activeTaskSets = new HashMap[String, LocalTaskSetManager] - val taskIdToTaskSetId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - var localActor: ActorRef = null - - override def start() { - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - - localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") - } - - override def setDAGScheduler(dagScheduler: DAGScheduler) { - this.dagScheduler = dagScheduler - } - - override def submitTasks(taskSet: TaskSet) { - synchronized { - val manager = new LocalTaskSetManager(this, taskSet) - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - activeTaskSets(taskSet.id) = manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - localActor ! LocalReviveOffers - } - } - - override def cancelTasks(stageId: Int): Unit = synchronized { - logInfo("Cancelling stage " + stageId) - logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId)) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - val taskIds = taskSetTaskIds(tsm.taskSet.id) - if (taskIds.size > 0) { - taskIds.foreach { tid => - localActor ! 
KillTask(tid) - } - } - tsm.error("Stage %d was cancelled".format(stageId)) - } - } - - def resourceOffer(freeCores: Int): Seq[TaskDescription] = { - synchronized { - var freeCpuCores = freeCores - val tasks = new ArrayBuffer[TaskDescription](freeCores) - val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() - for (manager <- sortedTaskSetQueue) { - logDebug("parentName:%s,name:%s,runningTasks:%s".format( - manager.parent.name, manager.name, manager.runningTasks)) - } - - var launchTask = false - for (manager <- sortedTaskSetQueue) { - do { - launchTask = false - manager.resourceOffer(null, null, freeCpuCores, null) match { - case Some(task) => - tasks += task - taskIdToTaskSetId(task.taskId) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += task.taskId - freeCpuCores -= 1 - launchTask = true - case None => {} - } - } while(launchTask) - } - return tasks - } - } - - def taskSetFinished(manager: TaskSetManager) { - synchronized { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds -= manager.taskSet.id - } - } - - override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - if (TaskState.isFinished(state)) { - synchronized { - taskIdToTaskSetId.get(taskId) match { - case Some(taskSetId) => - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId - - state match { - case TaskState.FINISHED => - taskSetManager.taskEnded(taskId, state, serializedData) - case TaskState.FAILED => - taskSetManager.taskFailed(taskId, state, serializedData) - case TaskState.KILLED => - taskSetManager.error("Task %d was killed".format(taskId)) - case _ => {} - } - case None => - logInfo("Ignoring update from TID " + taskId + " because its task set is gone") - } - } - localActor ! LocalStatusUpdate(taskId, state, serializedData) - } - } - - override def stop() { - } - - override def defaultParallelism() = threads -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala deleted file mode 100644 index 53bf78267e..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.local - -import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState} -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task, - TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager} - - -private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) - extends TaskSetManager with Logging { - - var parent: Pool = null - var weight: Int = 1 - var minShare: Int = 0 - var runningTasks: Int = 0 - var priority: Int = taskSet.priority - var stageId: Int = taskSet.stageId - var name: String = "TaskSet_" + taskSet.stageId.toString - - var failCount = new Array[Int](taskSet.tasks.size) - val taskInfos = new HashMap[Long, TaskInfo] - val numTasks = taskSet.tasks.size - var numFinished = 0 - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val MAX_TASK_FAILURES = sched.maxFailures - - def increaseRunningTasks(taskNum: Int): Unit = { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - def decreaseRunningTasks(taskNum: Int): Unit = { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def removeSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - // nothing - } - - override def checkSpeculatableTasks() = true - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - override def hasPendingTasks() = true - - def findTask(): Option[Int] = { - for (i <- 0 to numTasks-1) { - if (copiesRunning(i) == 0 && !finished(i)) { - return Some(i) - } - } - return None - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - SparkEnv.set(sched.env) - logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format( - availableCpus.toInt, numFinished, numTasks)) - if (availableCpus > 0 && numFinished < numTasks) { - findTask() match { - case Some(index) => - val taskId = sched.attemptId.getAndIncrement() - val task = taskSet.tasks(index) - val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", - TaskLocality.NODE_LOCAL) - taskInfos(taskId) = info - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. 
- val bytes = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") - val taskName = "task %s:%d".format(taskSet.id, index) - copiesRunning(index) += 1 - increaseRunningTasks(1) - taskStarted(task, info) - return Some(new TaskDescription(taskId, null, taskName, index, bytes)) - case None => {} - } - } - return None - } - - def taskStarted(task: Task[_], info: TaskInfo) { - sched.dagScheduler.taskStarted(task, info) - } - - def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markSuccessful() - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => { - throw new SparkException("Expect only DirectTaskResults when using LocalScheduler") - } - } - result.metrics.resultSize = serializedData.limit() - sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info, - result.metrics) - numFinished += 1 - decreaseRunningTasks(1) - finished(index) = true - if (numFinished == numTasks) { - sched.taskSetFinished(this) - } - } - - def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markFailed() - decreaseRunningTasks(1) - val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( - serializedData, getClass.getClassLoader) - sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) - if (!finished(index)) { - copiesRunning(index) -= 1 - numFailures(index) += 1 - val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - reason.className, reason.description, locs.mkString("\n"))) - if (numFailures(index) > MAX_TASK_FAILURES) { - val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( - taskSet.id, index, MAX_TASK_FAILURES, reason.description) - decreaseRunningTasks(runningTasks) - sched.dagScheduler.taskSetFailed(taskSet, errorMessage) - // need to delete failed Taskset from schedule queue - sched.taskSetFinished(this) - } - } - } - - override def error(message: String) { - sched.dagScheduler.taskSetFailed(taskSet, message) - sched.taskSetFinished(this) - } -} diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index af448fcb37..2f7d6dff38 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.FunSuite +import org.scalatest.{BeforeAndAfterAll, FunSuite} import SparkContext._ import org.apache.spark.util.NonSerializable @@ -37,12 +37,20 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite with LocalSparkContext { +class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAll { + + override def beforeAll { + System.setProperty("spark.task.maxFailures", "1") + } + + override def afterAll { + System.clearProperty("spark.task.maxFailures") + } // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. 
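+  // Note: the retry limit for these tests now comes from the spark.task.maxFailures property
+  // set in beforeAll above, rather than from the "local[1,1]" master URL used previously.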
test("failure in a single-stage job") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1]", "test") val results = sc.makeRDD(1 to 3, 3).map { x => FailureSuiteState.synchronized { FailureSuiteState.tasksRun += 1 @@ -62,7 +70,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { // Run a map-reduce job in which a reduce task deterministically fails once. test("failure in a two-stage job") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1]", "test") val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { case (k, v) => FailureSuiteState.synchronized { @@ -82,7 +90,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { } test("failure because task results are not serializable") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1]", "test") val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) val thrown = intercept[SparkException] { @@ -95,7 +103,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { } test("failure because task closure is not serializable") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1]", "test") val a = new NonSerializable // Non-serializable closure in the final result stage diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index f7f599532a..d9a1d6d087 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -19,23 +19,26 @@ package org.apache.spark.scheduler import scala.collection.mutable.{Buffer, HashSet} -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.matchers.ShouldMatchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.SparkContext._ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers - with BeforeAndAfterAll { + with BeforeAndAfter with BeforeAndAfterAll { /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 + before { + sc = new SparkContext("local", "SparkListenerSuite") + } + override def afterAll { System.clearProperty("spark.akka.frameSize") } test("basic creation of StageInfo") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -56,7 +59,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("StageInfo with fewer tasks than partitions") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -72,7 +74,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("local metrics") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) @@ -135,10 +136,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("onTaskGettingResult() called when result fetched remotely") { - // Need to use local cluster mode here, because results are not ever returned through the - // block manager when using the LocalScheduler. 
- sc = new SparkContext("local-cluster[1,1,512]", "test") - val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -157,10 +154,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("onTaskGettingResult() not called when result sent directly") { - // Need to use local cluster mode here, because results are not ever returned through the - // block manager when using the LocalScheduler. - sc = new SparkContext("local-cluster[1,1,512]", "test") - val listener = new SaveTaskEvents sc.addSparkListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala index ee150a3107..77d3038614 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala @@ -66,9 +66,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA } before { - // Use local-cluster mode because results are returned differently when running with the - // LocalScheduler. - sc = new SparkContext("local-cluster[1,1,512]", "test") + sc = new SparkContext("local", "test") } override def afterAll { diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala deleted file mode 100644 index 1e676c1719..0000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.local - -import java.util.concurrent.Semaphore -import java.util.concurrent.CountDownLatch - -import scala.collection.mutable.HashMap - -import org.scalatest.{BeforeAndAfterEach, FunSuite} - -import org.apache.spark._ - - -class Lock() { - var finished = false - def jobWait() = { - synchronized { - while(!finished) { - this.wait() - } - } - } - - def jobFinished() = { - synchronized { - finished = true - this.notifyAll() - } - } -} - -object TaskThreadInfo { - val threadToLock = HashMap[Int, Lock]() - val threadToRunning = HashMap[Int, Boolean]() - val threadToStarted = HashMap[Int, CountDownLatch]() -} - -/* - * 1. each thread contains one job. - * 2. each job contains one stage. - * 3. each stage only contains one task. - * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure - * it will get cpu core resource, and will wait to finished after user manually - * release "Lock" and then cluster will contain another free cpu cores. - * 5. 
each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, - * thus it will be scheduled later when cluster has free cpu cores. - */ -class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach { - - override def afterEach() { - super.afterEach() - System.clearProperty("spark.scheduler.mode") - } - - def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { - - TaskThreadInfo.threadToRunning(threadIndex) = false - val nums = sc.parallelize(threadIndex to threadIndex, 1) - TaskThreadInfo.threadToLock(threadIndex) = new Lock() - TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) - new Thread { - if (poolName != null) { - sc.setLocalProperty("spark.scheduler.pool", poolName) - } - override def run() { - val ans = nums.map(number => { - TaskThreadInfo.threadToRunning(number) = true - TaskThreadInfo.threadToStarted(number).countDown() - TaskThreadInfo.threadToLock(number).jobWait() - TaskThreadInfo.threadToRunning(number) = false - number - }).collect() - assert(ans.toList === List(threadIndex)) - sem.release() - } - }.start() - } - - test("Local FIFO scheduler end-to-end test") { - System.setProperty("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local[4]", "test") - val sem = new Semaphore(0) - - createThread(1,null,sc,sem) - TaskThreadInfo.threadToStarted(1).await() - createThread(2,null,sc,sem) - TaskThreadInfo.threadToStarted(2).await() - createThread(3,null,sc,sem) - TaskThreadInfo.threadToStarted(3).await() - createThread(4,null,sc,sem) - TaskThreadInfo.threadToStarted(4).await() - // thread 5 and 6 (stage pending)must meet following two points - // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager - // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() - // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 - // So I just use "sleep" 1s here for each thread. - // TODO: any better solution? 
- createThread(5,null,sc,sem) - Thread.sleep(1000) - createThread(6,null,sc,sem) - Thread.sleep(1000) - - assert(TaskThreadInfo.threadToRunning(1) === true) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === false) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(1).jobFinished() - TaskThreadInfo.threadToStarted(5).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(3).jobFinished() - TaskThreadInfo.threadToStarted(6).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === false) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === true) - - TaskThreadInfo.threadToLock(2).jobFinished() - TaskThreadInfo.threadToLock(4).jobFinished() - TaskThreadInfo.threadToLock(5).jobFinished() - TaskThreadInfo.threadToLock(6).jobFinished() - sem.acquire(6) - } - - test("Local fair scheduler end-to-end test") { - System.setProperty("spark.scheduler.mode", "FAIR") - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - - sc = new SparkContext("local[8]", "LocalSchedulerSuite") - val sem = new Semaphore(0) - - createThread(10,"1",sc,sem) - TaskThreadInfo.threadToStarted(10).await() - createThread(20,"2",sc,sem) - TaskThreadInfo.threadToStarted(20).await() - createThread(30,"3",sc,sem) - TaskThreadInfo.threadToStarted(30).await() - - assert(TaskThreadInfo.threadToRunning(10) === true) - assert(TaskThreadInfo.threadToRunning(20) === true) - assert(TaskThreadInfo.threadToRunning(30) === true) - - createThread(11,"1",sc,sem) - TaskThreadInfo.threadToStarted(11).await() - createThread(21,"2",sc,sem) - TaskThreadInfo.threadToStarted(21).await() - createThread(31,"3",sc,sem) - TaskThreadInfo.threadToStarted(31).await() - - assert(TaskThreadInfo.threadToRunning(11) === true) - assert(TaskThreadInfo.threadToRunning(21) === true) - assert(TaskThreadInfo.threadToRunning(31) === true) - - createThread(12,"1",sc,sem) - TaskThreadInfo.threadToStarted(12).await() - createThread(22,"2",sc,sem) - TaskThreadInfo.threadToStarted(22).await() - createThread(32,"3",sc,sem) - - assert(TaskThreadInfo.threadToRunning(12) === true) - assert(TaskThreadInfo.threadToRunning(22) === true) - assert(TaskThreadInfo.threadToRunning(32) === false) - - TaskThreadInfo.threadToLock(10).jobFinished() - TaskThreadInfo.threadToStarted(32).await() - - assert(TaskThreadInfo.threadToRunning(32) === true) - - //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager - // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. - //2. priority of 23 and 33 will be meaningless as using fair scheduler here. 
- createThread(23,"2",sc,sem) - createThread(33,"3",sc,sem) - Thread.sleep(1000) - - TaskThreadInfo.threadToLock(11).jobFinished() - TaskThreadInfo.threadToStarted(23).await() - - assert(TaskThreadInfo.threadToRunning(23) === true) - assert(TaskThreadInfo.threadToRunning(33) === false) - - TaskThreadInfo.threadToLock(12).jobFinished() - TaskThreadInfo.threadToStarted(33).await() - - assert(TaskThreadInfo.threadToRunning(33) === true) - - TaskThreadInfo.threadToLock(20).jobFinished() - TaskThreadInfo.threadToLock(21).jobFinished() - TaskThreadInfo.threadToLock(22).jobFinished() - TaskThreadInfo.threadToLock(23).jobFinished() - TaskThreadInfo.threadToLock(30).jobFinished() - TaskThreadInfo.threadToLock(31).jobFinished() - TaskThreadInfo.threadToLock(32).jobFinished() - TaskThreadInfo.threadToLock(33).jobFinished() - - sem.acquire(11) - } -} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 29b3f22e13..e873400680 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler.cluster +import org.apache.hadoop.conf.Configuration + import org.apache.spark._ import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.spark.scheduler.TaskScheduler import org.apache.spark.util.Utils -import org.apache.hadoop.conf.Configuration /** * - * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done + * This is a simple extension to TaskScheduler - to ensure that appropriate initialization of + * ApplicationMaster, etc. 
is done */ -private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) + extends TaskScheduler(sc) { logInfo("Created YarnClusterScheduler") -- cgit v1.2.3 From a124658e53a5abeda00a2582385a294c8e452d21 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 30 Oct 2013 19:29:38 -0700 Subject: Fixed most issues with unit tests --- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 94 +++--- .../org/apache/spark/scheduler/FakeTask.scala | 26 ++ .../spark/scheduler/TaskResultGetterSuite.scala | 111 +++++++ .../spark/scheduler/TaskSchedulerSuite.scala | 265 +++++++++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 317 ++++++++++++++++++++ .../scheduler/cluster/ClusterSchedulerSuite.scala | 267 ----------------- .../cluster/ClusterTaskSetManagerSuite.scala | 318 --------------------- .../apache/spark/scheduler/cluster/FakeTask.scala | 27 -- .../scheduler/cluster/TaskResultGetterSuite.scala | 112 -------- 9 files changed, 767 insertions(+), 770 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 00f2fdd657..394a1bb06f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -33,6 +33,24 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} +/** + * TaskScheduler that records the task sets that the DAGScheduler requested executed. + */ +class TaskSetRecordingTaskScheduler(sc: SparkContext) extends TaskScheduler(sc) { + /** Set of TaskSets the DAGScheduler has requested executed. */ + val taskSets = scala.collection.mutable.Buffer[TaskSet]() + override def start() = {} + override def stop() = {} + override def submitTasks(taskSet: TaskSet) = { + // normally done by TaskSetManager + taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) + taskSets += taskSet + } + override def cancelTasks(stageId: Int) {} + override def setDAGScheduler(dagScheduler: DAGScheduler) = {} + override def defaultParallelism() = 2 +} + /** * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler * rather than spawning an event loop thread as happens in the real code. They use EasyMock @@ -46,24 +64,7 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} * and capturing the resulting TaskSets from the mock TaskScheduler. 
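+ *
+ * The captured sets are read back through the recording scheduler, for example:
+ * {{{
+ *   complete(taskScheduler.taskSets(0), List((Success, 42)))
+ * }}}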
*/ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - - /** Set of TaskSets the DAGScheduler has requested executed. */ - val taskSets = scala.collection.mutable.Buffer[TaskSet]() - val taskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} - override def submitTasks(taskSet: TaskSet) = { - // normally done by TaskSetManager - taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) - taskSets += taskSet - } - override def cancelTasks(stageId: Int) {} - override def setDAGScheduler(dagScheduler: DAGScheduler) = {} - override def defaultParallelism() = 2 - } - + var taskScheduler: TaskSetRecordingTaskScheduler = null var mapOutputTracker: MapOutputTrackerMaster = null var scheduler: DAGScheduler = null @@ -96,7 +97,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont before { sc = new SparkContext("local", "DAGSchedulerSuite") - taskSets.clear() + taskScheduler = new TaskSetRecordingTaskScheduler(sc) + taskScheduler.taskSets.clear() cacheLocations.clear() results.clear() mapOutputTracker = new MapOutputTrackerMaster() @@ -204,7 +206,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont test("run trivial job") { val rdd = makeRdd(1, Nil) submit(rdd, Array(0)) - complete(taskSets(0), List((Success, 42))) + complete(taskScheduler.taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) } @@ -225,7 +227,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val baseRdd = makeRdd(1, Nil) val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) submit(finalRdd, Array(0)) - complete(taskSets(0), Seq((Success, 42))) + complete(taskScheduler.taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -235,7 +237,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont cacheLocations(baseRdd.id -> 0) = Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) submit(finalRdd, Array(0)) - val taskSet = taskSets(0) + val taskSet = taskScheduler.taskSets(0) assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) @@ -243,7 +245,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) - failed(taskSets(0), "some failure") + failed(taskScheduler.taskSets(0), "some failure") assert(failure.getMessage === "Job aborted: some failure") } @@ -253,12 +255,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(1, List(shuffleDep)) submit(reduceRdd, Array(0)) - complete(taskSets(0), Seq( + complete(taskScheduler.taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) - complete(taskSets(1), Seq((Success, 42))) + complete(taskScheduler.taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -268,11 +270,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) - complete(taskSets(0), Seq( + 
complete(taskScheduler.taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) // the 2nd ResultTask failed - complete(taskSets(1), Seq( + complete(taskScheduler.taskSets(1), Seq( (Success, 42), (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) // this will get called @@ -280,10 +282,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskScheduler.taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) - complete(taskSets(3), Seq((Success, 43))) + complete(taskScheduler.taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) } @@ -299,7 +301,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) val noAccum = Map[Long, Any]() - val taskSet = taskSets(0) + val taskSet = taskScheduler.taskSets(0) // should be ignored for being too old runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) // should work because it's a non-failed host @@ -311,7 +313,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) - complete(taskSets(1), Seq((Success, 42), (Success, 43))) + complete(taskScheduler.taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) } @@ -326,14 +328,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont runEvent(ExecutorLost("exec-hostA")) // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks // rather than marking it is as failed and waiting. 
- complete(taskSets(0), Seq( + complete(taskScheduler.taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task - complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) + complete(taskScheduler.taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) - complete(taskSets(2), Seq((Success, 42))) + complete(taskScheduler.taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -345,23 +347,23 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val finalRdd = makeRdd(1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) // have the first stage complete normally - complete(taskSets(0), Seq( + complete(taskScheduler.taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // have the second stage complete normally - complete(taskSets(1), Seq( + complete(taskScheduler.taskSets(1), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down - complete(taskSets(2), Seq( + complete(taskScheduler.taskSets(2), Seq( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again scheduler.resubmitFailedStages() - complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2)))) - complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) - complete(taskSets(5), Seq((Success, 42))) + complete(taskScheduler.taskSets(3), Seq((Success, makeMapStatus("hostA", 2)))) + complete(taskScheduler.taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskScheduler.taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -375,24 +377,24 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) // complete stage 2 - complete(taskSets(0), Seq( + complete(taskScheduler.taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // complete stage 1 - complete(taskSets(1), Seq( + complete(taskScheduler.taskSets(1), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) // pretend stage 0 failed because hostA went down - complete(taskSets(2), Seq( + complete(taskScheduler.taskSets(2), Seq( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. 
scheduler.resubmitFailedStages() - assertLocations(taskSets(3), Seq(Seq("hostD"))) + assertLocations(taskScheduler.taskSets(3), Seq(Seq("hostD"))) // allow hostD to recover - complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) - complete(taskSets(4), Seq((Success, 42))) + complete(taskScheduler.taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) + complete(taskScheduler.taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala new file mode 100644 index 0000000000..0b90c4e74c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.TaskContext + +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { + override def runTask(context: TaskContext): Int = 0 + + override def preferredLocations: Seq[TaskLocation] = prefLocs +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala new file mode 100644 index 0000000000..30e6bc5721 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.nio.ByteBuffer + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv} +import org.apache.spark.storage.TaskResultBlockId + +/** + * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. + * + * Used to test the case where a BlockManager evicts the task result (or dies) before the + * TaskResult is retrieved. 
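+ *
+ * Note that enqueueSuccessfulTask deserializes the buffer once to locate the result block, so
+ * it calls serializedData.rewind() before delegating to the parent getter, which deserializes
+ * the same buffer again.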
+ */ +class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskScheduler) + extends TaskResultGetter(sparkEnv, scheduler) { + var removedResult = false + + override def enqueueSuccessfulTask( + taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { + if (!removedResult) { + // Only remove the result once, since we'd like to test the case where the task eventually + // succeeds. + serializer.get().deserialize[TaskResult[_]](serializedData) match { + case IndirectTaskResult(blockId) => + sparkEnv.blockManager.master.removeBlock(blockId) + case directResult: DirectTaskResult[_] => + taskSetManager.abort("Internal error: expect only indirect results") + } + serializedData.rewind() + removedResult = true + } + super.enqueueSuccessfulTask(taskSetManager, tid, serializedData) + } +} + +/** + * Tests related to handling task results (both direct and indirect). + */ +class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll + with LocalSparkContext { + + override def beforeAll { + // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small + // as we can make it) so the tests don't take too long. + System.setProperty("spark.akka.frameSize", "1") + } + + before { + sc = new SparkContext("local", "test") + } + + override def afterAll { + System.clearProperty("spark.akka.frameSize") + } + + test("handling results smaller than Akka frame size") { + val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) + assert(result === 2) + } + + test("handling results larger than Akka frame size") { + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) + assert(result === 1.to(akkaFrameSize).toArray) + + val RESULT_BLOCK_ID = TaskResultBlockId(0) + assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0, + "Expect result to be removed from the block manager.") + } + + test("task retried if result missing from block manager") { + // If this test hangs, it's probably because no resource offers were made after the task + // failed. + val scheduler: TaskScheduler = sc.taskScheduler match { + case clusterScheduler: TaskScheduler => + clusterScheduler + case _ => + assert(false, "Expect local cluster to use TaskScheduler") + throw new ClassCastException + } + scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) + assert(result === 1.to(akkaFrameSize).toArray) + + // Make sure two tasks were run (one failed one, and a second retried one). + assert(scheduler.nextTaskId.get() === 2) + } +} + diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala new file mode 100644 index 0000000000..bfbffdf261 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import scala.collection.mutable.ArrayBuffer + +import java.util.Properties + +class FakeTaskSetManager( + initPriority: Int, + initStageId: Int, + initNumTasks: Int, + taskScheduler: TaskScheduler, + taskSet: TaskSet) + extends TaskSetManager(taskScheduler, taskSet) { + + parent = null + weight = 1 + minShare = 2 + runningTasks = 0 + priority = initPriority + stageId = initStageId + name = "TaskSet_"+stageId + override val numTasks = initNumTasks + tasksSuccessful = 0 + + def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def addSchedulable(schedulable: Schedulable) { + } + + override def removeSchedulable(schedulable: Schedulable) { + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def executorLost(executorId: String, host: String): Unit = { + } + + override def resourceOffer( + execId: String, + host: String, + availableCpus: Int, + maxLocality: TaskLocality.TaskLocality) + : Option[TaskDescription] = + { + if (tasksSuccessful + runningTasks < numTasks) { + increaseRunningTasks(1) + return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) + } + return None + } + + override def checkSpeculatableTasks(): Boolean = { + return true + } + + def taskFinished() { + decreaseRunningTasks(1) + tasksSuccessful +=1 + if (tasksSuccessful == numTasks) { + parent.removeSchedulable(this) + } + } + + def abort() { + decreaseRunningTasks(runningTasks) + parent.removeSchedulable(this) + } +} + +class TaskSchedulerSuite extends FunSuite with LocalSparkContext with Logging { + + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskScheduler, taskSet: TaskSet): FakeTaskSetManager = { + new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) + } + + def resourceOffer(rootPool: Pool): Int = { + val taskSetQueue = rootPool.getSortedTaskSetQueue() + /* Just for Test*/ + for (manager <- taskSetQueue) { + logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format( + manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) + } + for (taskSet <- taskSetQueue) { + taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match { + case Some(task) => + return taskSet.stageId + case None => {} + } + } + -1 + } + + def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { + assert(resourceOffer(rootPool) === expectedTaskSetId) + } + + test("FIFO Scheduler Test") { + sc = new SparkContext("local", "TaskSchedulerSuite") + val taskScheduler = new TaskScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task 
= new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet) + val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + schedulableBuilder.addTaskSetManager(taskSetManager1, null) + schedulableBuilder.addTaskSetManager(taskSetManager2, null) + + checkTaskSetId(rootPool, 0) + resourceOffer(rootPool) + checkTaskSetId(rootPool, 1) + resourceOffer(rootPool) + taskSetManager1.abort() + checkTaskSetId(rootPool, 2) + } + + test("Fair Scheduler Test") { + sc = new SparkContext("local", "TaskSchedulerSuite") + val taskScheduler = new TaskScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.scheduler.allocation.file", xmlPath) + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName("default") != null) + assert(rootPool.getSchedulableByName("1") != null) + assert(rootPool.getSchedulableByName("2") != null) + assert(rootPool.getSchedulableByName("3") != null) + assert(rootPool.getSchedulableByName("1").minShare === 2) + assert(rootPool.getSchedulableByName("1").weight === 1) + assert(rootPool.getSchedulableByName("2").minShare === 3) + assert(rootPool.getSchedulableByName("2").weight === 1) + assert(rootPool.getSchedulableByName("3").minShare === 0) + assert(rootPool.getSchedulableByName("3").weight === 1) + + val properties1 = new Properties() + properties1.setProperty("spark.scheduler.pool","1") + val properties2 = new Properties() + properties2.setProperty("spark.scheduler.pool","2") + + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) + + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) + schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 1) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 4) + + taskSetManager12.taskFinished() + assert(rootPool.getSchedulableByName("1").runningTasks === 3) + taskSetManager24.abort() + assert(rootPool.getSchedulableByName("2").runningTasks === 2) + } + + test("Nested Pool Test") { + sc = new SparkContext("local", "TaskSchedulerSuite") + val taskScheduler = new TaskScheduler(sc) + var tasks = 
ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) + val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) + rootPool.addSchedulable(pool0) + rootPool.addSchedulable(pool1) + + val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) + val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) + pool0.addSchedulable(pool00) + pool0.addSchedulable(pool01) + + val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) + val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) + pool1.addSchedulable(pool10) + pool1.addSchedulable(pool11) + + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet) + pool00.addSchedulable(taskSetManager000) + pool00.addSchedulable(taskSetManager001) + + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet) + pool01.addSchedulable(taskSetManager010) + pool01.addSchedulable(taskSetManager011) + + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet) + pool10.addSchedulable(taskSetManager100) + pool10.addSchedulable(taskSetManager101) + + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet) + pool11.addSchedulable(taskSetManager110) + pool11.addSchedulable(taskSetManager111) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 6) + checkTaskSetId(rootPool, 2) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala new file mode 100644 index 0000000000..fe3ea7b594 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
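The expected stage IDs in the Fair Scheduler and Nested Pool tests above follow from the FAIR comparison between schedulables: entries running fewer tasks than their minShare are served first, ordered by runningTasks/minShare, and everything else is ordered by runningTasks/weight. The standalone Scala sketch below restates that rule over an invented FairEntry type so the expected orderings can be checked by hand; it is an illustration of the assumed comparison, not the FairSchedulingAlgorithm shipped in this patch.

    // Simplified fair-share ordering, for reasoning about the checkTaskSetId expectations above.
    // FairEntry is a hypothetical stand-in for a Schedulable (a pool or a task set manager).
    case class FairEntry(name: String, runningTasks: Int, minShare: Int, weight: Int)

    object FairOrderingSketch {
      // Returns true if `a` should be offered resources before `b`.
      def before(a: FairEntry, b: FairEntry): Boolean = {
        val aNeedy = a.runningTasks < a.minShare
        val bNeedy = b.runningTasks < b.minShare
        val aMinShareRatio = a.runningTasks.toDouble / math.max(a.minShare, 1)
        val bMinShareRatio = b.runningTasks.toDouble / math.max(b.minShare, 1)
        val aTaskToWeightRatio = a.runningTasks.toDouble / a.weight
        val bTaskToWeightRatio = b.runningTasks.toDouble / b.weight
        if (aNeedy && !bNeedy) true
        else if (!aNeedy && bNeedy) false
        else if (aNeedy && bNeedy) aMinShareRatio < bMinShareRatio
        else aTaskToWeightRatio < bTaskToWeightRatio
      }

      def main(args: Array[String]) {
        // Pools shaped like the test's pools "1" (minShare 2) and "2" (minShare 3), one running task each.
        val pool1 = FairEntry("1", runningTasks = 1, minShare = 2, weight = 1)
        val pool2 = FairEntry("2", runningTasks = 1, minShare = 3, weight = 1)
        println(before(pool1, pool2)) // false: pool "2" is further below its min share, so it goes first
      }
    }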
+ */ + +package org.apache.spark.scheduler + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import java.nio.ByteBuffer +import org.apache.spark.util.{Utils, FakeClock} + +class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) { + taskScheduler.startedTasks += taskInfo.index + } + + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: mutable.Map[Long, Any], + taskInfo: TaskInfo, + taskMetrics: TaskMetrics) { + taskScheduler.endedTasks(taskInfo.index) = reason + } + + override def executorGained(execId: String, host: String) {} + + override def executorLost(execId: String) {} + + override def taskSetFailed(taskSet: TaskSet, reason: String) { + taskScheduler.taskSetsFailed += taskSet.id + } +} + +/** + * A mock TaskScheduler implementation that just remembers information about tasks started and + * feedback received from the TaskSetManagers. Note that it's important to initialize this with + * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost + * to work, and these are required for locality in TaskSetManager. + */ +class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) + extends TaskScheduler(sc) +{ + val startedTasks = new ArrayBuffer[Long] + val endedTasks = new mutable.HashMap[Long, TaskEndReason] + val finishedManagers = new ArrayBuffer[TaskSetManager] + val taskSetsFailed = new ArrayBuffer[String] + + val executors = new mutable.HashMap[String, String] ++ liveExecutors + + dagScheduler = new FakeDAGScheduler(this) + + def removeExecutor(execId: String): Unit = executors -= execId + + override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager + + override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) + + override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) +} + +class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { + import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} + + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + test("TaskSet with no preferences") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(1) + val manager = new TaskSetManager(sched, taskSet) + + // Offer a host with no CPUs + assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) + + // Offer a host with process-local as the constraint; this should work because the TaskSet + // above won't have any locality preferences + val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === "exec1") + assert(sched.startedTasks.contains(0)) + + // Re-offer the host -- now we should get no more tasks + assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None) + + // Tell it the task has finished + manager.handleSuccessfulTask(0, createTaskResult(0)) + assert(sched.endedTasks(0) === Success) + assert(sched.finishedManagers.contains(manager)) + } + + test("multiple offers with no preferences") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, 
("exec1", "host1")) + val taskSet = createTaskSet(3) + val manager = new TaskSetManager(sched, taskSet) + + // First three offers should all find tasks + for (i <- 0 until 3) { + val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === "exec1") + } + assert(sched.startedTasks.toSet === Set(0, 1, 2)) + + // Re-offer the host -- now we should get no more tasks + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + // Finish the first two tasks + manager.handleSuccessfulTask(0, createTaskResult(0)) + manager.handleSuccessfulTask(1, createTaskResult(1)) + assert(sched.endedTasks(0) === Success) + assert(sched.endedTasks(1) === Success) + assert(!sched.finishedManagers.contains(manager)) + + // Finish the last task + manager.handleSuccessfulTask(2, createTaskResult(2)) + assert(sched.endedTasks(2) === Success) + assert(sched.finishedManagers.contains(manager)) + } + + test("basic delay scheduling") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = createTaskSet(4, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2", "exec2")), + Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), + Seq() // Last task has no locality prefs + ) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, clock) + + // First offer host1, exec1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1, exec1 again: the last task, which has no prefs, should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3) + + // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 + assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2) + + // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None) + + // Offer host1, exec1 again, at ANY level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1, exec1 again, at ANY level: task 1 should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + } + + test("delay scheduling with fallback") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, + ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) + val taskSet = createTaskSet(5, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host3")), + Seq(TaskLocation("host2")) + ) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, clock) + + // First offer host1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1 again: nothing should get 
chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1 again: second task (on host2) should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Offer host1 again: third task (on host2) should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) + + // Offer host2: fifth task (also on host2) should get chosen + assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4) + + // Now that we've launched a local task, we should no longer launch the task for host3 + assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // After another delay, we can go ahead and launch that task non-locally + assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3) + } + + test("delay scheduling with failed hosts") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = createTaskSet(3, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host3")) + ) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, clock) + + // First offer host1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1 again: third task should be chosen immediately because host3 is not up + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) + + // After this, nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + // Now mark host2 as dead + sched.removeExecutor("exec2") + manager.executorLost("exec2", "host2") + + // Task 1 should immediately be launched on host1 because its original host is gone + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Now that all tasks have launched, nothing new should be launched anywhere else + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) + } + + test("task result lost") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(1) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, clock) + + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Tell it the task has finished but the result was lost. + manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost)) + assert(sched.endedTasks(0) === TaskResultLost) + + // Re-offer the host -- now we should get task 0 again. + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + } + + test("repeated failures lead to task set abortion") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(1) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, clock) + + // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted + // after the last failure. 
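The delay-scheduling tests above rely on a single rule: a TaskSetManager only launches tasks at or below its currently allowed locality level, and that level relaxes one step toward ANY each time spark.locality.wait elapses with no launch at the stricter level, which is why the tests advance a FakeClock by LOCALITY_WAIT between offers. The snippet below is a minimal standalone model of that relaxation with invented names; it is not the TaskSetManager implementation itself.

    // Minimal model of locality-level relaxation driven by a wait timeout.
    object DelaySchedulingSketch {
      // Ordered from most to least strict, mirroring TaskLocality's ordering.
      val levels = Seq("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY")
      val localityWaitMs = 3000L // same default as spark.locality.wait

      // Given the time of the last launch at the current level, decide which level to allow now.
      def allowedLevel(currentIndex: Int, lastLaunchTime: Long, now: Long): Int = {
        var index = currentIndex
        var launchTime = lastLaunchTime
        // Jump past each level whose wait has fully elapsed without a launch.
        while (index < levels.length - 1 && now - launchTime >= localityWaitMs) {
          launchTime += localityWaitMs
          index += 1
        }
        index
      }

      def main(args: Array[String]) {
        println(levels(allowedLevel(0, lastLaunchTime = 0L, now = 0L)))    // PROCESS_LOCAL
        println(levels(allowedLevel(0, lastLaunchTime = 0L, now = 3000L))) // NODE_LOCAL
        println(levels(allowedLevel(0, lastLaunchTime = 0L, now = 6000L))) // RACK_LOCAL
      }
    }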
+ (0 until manager.MAX_TASK_FAILURES).foreach { index => + val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) + assert(offerResult != None, + "Expect resource offer on iteration %s to return a task".format(index)) + assert(offerResult.get.index === 0) + manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost)) + if (index < manager.MAX_TASK_FAILURES) { + assert(!sched.taskSetsFailed.contains(taskSet.id)) + } else { + assert(sched.taskSetsFailed.contains(taskSet.id)) + } + } + } + + + /** + * Utility method to create a TaskSet, potentially setting a particular sequence of preferred + * locations for each task (given as varargs) if this sequence is not empty. + */ + def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) + } + new TaskSet(tasks, 0, 0, 0, null) + } + + def createTaskResult(id: Int): DirectTaskResult[Int] = { + new DirectTaskResult[Int](id, mutable.Map.empty, new TaskMetrics) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala deleted file mode 100644 index 95d3553d91..0000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.cluster - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster._ -import scala.collection.mutable.ArrayBuffer - -import java.util.Properties - -class FakeTaskSetManager( - initPriority: Int, - initStageId: Int, - initNumTasks: Int, - clusterScheduler: ClusterScheduler, - taskSet: TaskSet) - extends ClusterTaskSetManager(clusterScheduler, taskSet) { - - parent = null - weight = 1 - minShare = 2 - runningTasks = 0 - priority = initPriority - stageId = initStageId - name = "TaskSet_"+stageId - override val numTasks = initNumTasks - tasksSuccessful = 0 - - def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable) { - } - - override def removeSchedulable(schedulable: Schedulable) { - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksSuccessful + runningTasks < numTasks) { - increaseRunningTasks(1) - return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) - } - return None - } - - override def checkSpeculatableTasks(): Boolean = { - return true - } - - def taskFinished() { - decreaseRunningTasks(1) - tasksSuccessful +=1 - if (tasksSuccessful == numTasks) { - parent.removeSchedulable(this) - } - } - - def abort() { - decreaseRunningTasks(runningTasks) - parent.removeSchedulable(this) - } -} - -class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { - - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { - new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) - } - - def resourceOffer(rootPool: Pool): Int = { - val taskSetQueue = rootPool.getSortedTaskSetQueue() - /* Just for Test*/ - for (manager <- taskSetQueue) { - logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format( - manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) - } - for (taskSet <- taskSetQueue) { - taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match { - case Some(task) => - return taskSet.stageId - case None => {} - } - } - -1 - } - - def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { - assert(resourceOffer(rootPool) === expectedTaskSetId) - } - - test("FIFO Scheduler Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) - val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() - - val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) - val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) - val taskSetManager2 = 
createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager0, null) - schedulableBuilder.addTaskSetManager(taskSetManager1, null) - schedulableBuilder.addTaskSetManager(taskSetManager2, null) - - checkTaskSetId(rootPool, 0) - resourceOffer(rootPool) - checkTaskSetId(rootPool, 1) - resourceOffer(rootPool) - taskSetManager1.abort() - checkTaskSetId(rootPool, 2) - } - - test("Fair Scheduler Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val schedulableBuilder = new FairSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() - - assert(rootPool.getSchedulableByName("default") != null) - assert(rootPool.getSchedulableByName("1") != null) - assert(rootPool.getSchedulableByName("2") != null) - assert(rootPool.getSchedulableByName("3") != null) - assert(rootPool.getSchedulableByName("1").minShare === 2) - assert(rootPool.getSchedulableByName("1").weight === 1) - assert(rootPool.getSchedulableByName("2").minShare === 3) - assert(rootPool.getSchedulableByName("2").weight === 1) - assert(rootPool.getSchedulableByName("3").minShare === 0) - assert(rootPool.getSchedulableByName("3").weight === 1) - - val properties1 = new Properties() - properties1.setProperty("spark.scheduler.pool","1") - val properties2 = new Properties() - properties2.setProperty("spark.scheduler.pool","2") - - val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) - val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) - val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) - schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) - schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) - - val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) - val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) - schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) - - checkTaskSetId(rootPool, 0) - checkTaskSetId(rootPool, 3) - checkTaskSetId(rootPool, 3) - checkTaskSetId(rootPool, 1) - checkTaskSetId(rootPool, 4) - checkTaskSetId(rootPool, 2) - checkTaskSetId(rootPool, 2) - checkTaskSetId(rootPool, 4) - - taskSetManager12.taskFinished() - assert(rootPool.getSchedulableByName("1").runningTasks === 3) - taskSetManager24.abort() - assert(rootPool.getSchedulableByName("2").runningTasks === 2) - } - - test("Nested Pool Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) - val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) - rootPool.addSchedulable(pool0) - rootPool.addSchedulable(pool1) - - val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 
2) - val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) - pool0.addSchedulable(pool00) - pool0.addSchedulable(pool01) - - val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) - val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) - pool1.addSchedulable(pool10) - pool1.addSchedulable(pool11) - - val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) - val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) - pool00.addSchedulable(taskSetManager000) - pool00.addSchedulable(taskSetManager001) - - val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) - val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) - pool01.addSchedulable(taskSetManager010) - pool01.addSchedulable(taskSetManager011) - - val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) - val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) - pool10.addSchedulable(taskSetManager100) - pool10.addSchedulable(taskSetManager101) - - val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) - val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) - pool11.addSchedulable(taskSetManager110) - pool11.addSchedulable(taskSetManager111) - - checkTaskSetId(rootPool, 0) - checkTaskSetId(rootPool, 4) - checkTaskSetId(rootPool, 6) - checkTaskSetId(rootPool, 2) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala deleted file mode 100644 index b97f2b19b5..0000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ /dev/null @@ -1,318 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.cluster - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable - -import org.scalatest.FunSuite - -import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.executor.TaskMetrics -import java.nio.ByteBuffer -import org.apache.spark.util.{Utils, FakeClock} - -class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) { - override def taskStarted(task: Task[_], taskInfo: TaskInfo) { - taskScheduler.startedTasks += taskInfo.index - } - - override def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: mutable.Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { - taskScheduler.endedTasks(taskInfo.index) = reason - } - - override def executorGained(execId: String, host: String) {} - - override def executorLost(execId: String) {} - - override def taskSetFailed(taskSet: TaskSet, reason: String) { - taskScheduler.taskSetsFailed += taskSet.id - } -} - -/** - * A mock ClusterScheduler implementation that just remembers information about tasks started and - * feedback received from the TaskSetManagers. Note that it's important to initialize this with - * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost - * to work, and these are required for locality in ClusterTaskSetManager. - */ -class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) - extends ClusterScheduler(sc) -{ - val startedTasks = new ArrayBuffer[Long] - val endedTasks = new mutable.HashMap[Long, TaskEndReason] - val finishedManagers = new ArrayBuffer[TaskSetManager] - val taskSetsFailed = new ArrayBuffer[String] - - val executors = new mutable.HashMap[String, String] ++ liveExecutors - - dagScheduler = new FakeDAGScheduler(this) - - def removeExecutor(execId: String): Unit = executors -= execId - - override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager - - override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) - - override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) -} - -class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { - import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} - - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - - test("TaskSet with no preferences") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) - val manager = new ClusterTaskSetManager(sched, taskSet) - - // Offer a host with no CPUs - assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) - - // Offer a host with process-local as the constraint; this should work because the TaskSet - // above won't have any locality preferences - val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL) - assert(taskOption.isDefined) - val task = taskOption.get - assert(task.executorId === "exec1") - assert(sched.startedTasks.contains(0)) - - // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None) - - // Tell it the task has finished - manager.handleSuccessfulTask(0, createTaskResult(0)) - assert(sched.endedTasks(0) === Success) - assert(sched.finishedManagers.contains(manager)) - } - - test("multiple offers with no preferences") { - sc 
= new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(3) - val manager = new ClusterTaskSetManager(sched, taskSet) - - // First three offers should all find tasks - for (i <- 0 until 3) { - val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) - assert(taskOption.isDefined) - val task = taskOption.get - assert(task.executorId === "exec1") - } - assert(sched.startedTasks.toSet === Set(0, 1, 2)) - - // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - // Finish the first two tasks - manager.handleSuccessfulTask(0, createTaskResult(0)) - manager.handleSuccessfulTask(1, createTaskResult(1)) - assert(sched.endedTasks(0) === Success) - assert(sched.endedTasks(1) === Success) - assert(!sched.finishedManagers.contains(manager)) - - // Finish the last task - manager.handleSuccessfulTask(2, createTaskResult(2)) - assert(sched.endedTasks(2) === Success) - assert(sched.finishedManagers.contains(manager)) - } - - test("basic delay scheduling") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(4, - Seq(TaskLocation("host1", "exec1")), - Seq(TaskLocation("host2", "exec2")), - Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), - Seq() // Last task has no locality prefs - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1, exec1: first task should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1, exec1 again: the last task, which has no prefs, should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 - assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2) - - // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None) - - // Offer host1, exec1 again, at ANY level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1, exec1 again, at ANY level: task 1 should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - } - - test("delay scheduling with fallback") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, - ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) - val taskSet = createTaskSet(5, - Seq(TaskLocation("host1")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host3")), - Seq(TaskLocation("host2")) - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1: first task should be chosen - 
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1 again: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1 again: second task (on host2) should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Offer host1 again: third task (on host2) should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) - - // Offer host2: fifth task (also on host2) should get chosen - assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4) - - // Now that we've launched a local task, we should no longer launch the task for host3 - assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // After another delay, we can go ahead and launch that task non-locally - assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3) - } - - test("delay scheduling with failed hosts") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(3, - Seq(TaskLocation("host1")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host3")) - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1: first task should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1 again: third task should be chosen immediately because host3 is not up - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) - - // After this, nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - // Now mark host2 as dead - sched.removeExecutor("exec2") - manager.executorLost("exec2", "host2") - - // Task 1 should immediately be launched on host1 because its original host is gone - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Now that all tasks have launched, nothing new should be launched anywhere else - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) - } - - test("task result lost") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Tell it the task has finished but the result was lost. - manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost)) - assert(sched.endedTasks(0) === TaskResultLost) - - // Re-offer the host -- now we should get task 0 again. - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - } - - test("repeated failures lead to task set abortion") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted - // after the last failure. 
- (0 until manager.MAX_TASK_FAILURES).foreach { index => - val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) - assert(offerResult != None, - "Expect resource offer on iteration %s to return a task".format(index)) - assert(offerResult.get.index === 0) - manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost)) - if (index < manager.MAX_TASK_FAILURES) { - assert(!sched.taskSetsFailed.contains(taskSet.id)) - } else { - assert(sched.taskSetsFailed.contains(taskSet.id)) - } - } - } - - - /** - * Utility method to create a TaskSet, potentially setting a particular sequence of preferred - * locations for each task (given as varargs) if this sequence is not empty. - */ - def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { - if (prefLocs.size != 0 && prefLocs.size != numTasks) { - throw new IllegalArgumentException("Wrong number of task locations") - } - val tasks = Array.tabulate[Task[_]](numTasks) { i => - new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) - } - new TaskSet(tasks, 0, 0, 0, null) - } - - def createTaskResult(id: Int): DirectTaskResult[Int] = { - new DirectTaskResult[Int](id, mutable.Map.empty, new TaskMetrics) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala deleted file mode 100644 index 0f01515179..0000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import org.apache.spark.TaskContext -import org.apache.spark.scheduler.{TaskLocation, Task} - -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { - override def runTask(context: TaskContext): Int = 0 - - override def preferredLocations: Seq[TaskLocation] = prefLocs -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala deleted file mode 100644 index 77d3038614..0000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import java.nio.ByteBuffer - -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} - -import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv} -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} -import org.apache.spark.storage.TaskResultBlockId - -/** - * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. - * - * Used to test the case where a BlockManager evicts the task result (or dies) before the - * TaskResult is retrieved. - */ -class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) - extends TaskResultGetter(sparkEnv, scheduler) { - var removedResult = false - - override def enqueueSuccessfulTask( - taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) { - if (!removedResult) { - // Only remove the result once, since we'd like to test the case where the task eventually - // succeeds. - serializer.get().deserialize[TaskResult[_]](serializedData) match { - case IndirectTaskResult(blockId) => - sparkEnv.blockManager.master.removeBlock(blockId) - case directResult: DirectTaskResult[_] => - taskSetManager.abort("Internal error: expect only indirect results") - } - serializedData.rewind() - removedResult = true - } - super.enqueueSuccessfulTask(taskSetManager, tid, serializedData) - } -} - -/** - * Tests related to handling task results (both direct and indirect). - */ -class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll - with LocalSparkContext { - - override def beforeAll { - // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small - // as we can make it) so the tests don't take too long. - System.setProperty("spark.akka.frameSize", "1") - } - - before { - sc = new SparkContext("local", "test") - } - - override def afterAll { - System.clearProperty("spark.akka.frameSize") - } - - test("handling results smaller than Akka frame size") { - val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) - assert(result === 2) - } - - test("handling results larger than Akka frame size") { - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt - val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) - assert(result === 1.to(akkaFrameSize).toArray) - - val RESULT_BLOCK_ID = TaskResultBlockId(0) - assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0, - "Expect result to be removed from the block manager.") - } - - test("task retried if result missing from block manager") { - // If this test hangs, it's probably because no resource offers were made after the task - // failed. 
- val scheduler: ClusterScheduler = sc.taskScheduler match { - case clusterScheduler: ClusterScheduler => - clusterScheduler - case _ => - assert(false, "Expect local cluster to use ClusterScheduler") - throw new ClassCastException - } - scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt - val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) - assert(result === 1.to(akkaFrameSize).toArray) - - // Make sure two tasks were run (one failed one, and a second retried one). - assert(scheduler.nextTaskId.get() === 2) - } -} - -- cgit v1.2.3 From fb64828b0b573f3a77938592f168af7aa3a2b6c5 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 31 Oct 2013 23:42:56 -0700 Subject: Cleaned up imports and fixed test bug --- .../main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 3 +-- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 1 - .../scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 9 +++++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 3f694dd25d..b4ec695ece 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -27,7 +27,6 @@ import scala.collection.mutable.HashSet import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** @@ -449,7 +448,7 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals } -object TaskScheduler { +private[spark] object TaskScheduler { /** * Used to balance containers across hosts. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 13271b10f3..90b6519027 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -28,7 +28,6 @@ import scala.math.min import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler._ import org.apache.spark.util.{SystemClock, Clock} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 394a1bb06f..5b5a2178f3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -36,14 +36,15 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} /** * TaskScheduler that records the task sets that the DAGScheduler requested executed. */ -class TaskSetRecordingTaskScheduler(sc: SparkContext) extends TaskScheduler(sc) { +class TaskSetRecordingTaskScheduler(sc: SparkContext, + mapOutputTrackerMaster: MapOutputTrackerMaster) extends TaskScheduler(sc) { /** Set of TaskSets the DAGScheduler has requested executed. 
*/ val taskSets = scala.collection.mutable.Buffer[TaskSet]() override def start() = {} override def stop() = {} override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager - taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) + taskSet.tasks.foreach(_.epoch = mapOutputTrackerMaster.getEpoch) taskSets += taskSet } override def cancelTasks(stageId: Int) {} @@ -97,11 +98,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont before { sc = new SparkContext("local", "DAGSchedulerSuite") - taskScheduler = new TaskSetRecordingTaskScheduler(sc) + mapOutputTracker = new MapOutputTrackerMaster() + taskScheduler = new TaskSetRecordingTaskScheduler(sc, mapOutputTracker) taskScheduler.taskSets.clear() cacheLocations.clear() results.clear() - mapOutputTracker = new MapOutputTrackerMaster() scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) { override def runLocally(job: ActiveJob) { // don't bother with the thread while unit testing -- cgit v1.2.3 From 68e5ad58b7e7e3e1b42852de8d0fdf9e9b9c1a14 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 13 Nov 2013 14:32:50 -0800 Subject: Extracted TaskScheduler interface. Also changed the default maximum number of task failures to be 0 when running in local mode. --- .../main/scala/org/apache/spark/SparkContext.scala | 14 +- .../apache/spark/scheduler/ClusterScheduler.scala | 493 +++++++++++++++++++++ .../apache/spark/scheduler/TaskResultGetter.scala | 2 +- .../org/apache/spark/scheduler/TaskScheduler.scala | 492 -------------------- .../apache/spark/scheduler/TaskSetManager.scala | 17 +- .../cluster/CoarseGrainedSchedulerBackend.scala | 4 +- .../scheduler/cluster/SimrSchedulerBackend.scala | 4 +- .../cluster/SparkDeploySchedulerBackend.scala | 4 +- .../mesos/CoarseMesosSchedulerBackend.scala | 4 +- .../cluster/mesos/MesosSchedulerBackend.scala | 4 +- .../spark/scheduler/local/LocalBackend.scala | 9 +- .../spark/scheduler/ClusterSchedulerSuite.scala | 265 +++++++++++ .../apache/spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../spark/scheduler/TaskResultGetterSuite.scala | 6 +- .../spark/scheduler/TaskSchedulerSuite.scala | 265 ----------- .../spark/scheduler/TaskSetManagerSuite.scala | 21 +- 16 files changed, 806 insertions(+), 800 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1850436ff2..e8ff4da475 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -159,26 +159,26 @@ class SparkContext( master match { case "local" => - val scheduler = new TaskScheduler(this) + val scheduler = new ClusterScheduler(this, isLocal = true) val backend = new LocalBackend(scheduler, 1) scheduler.initialize(backend) scheduler case LOCAL_N_REGEX(threads) => - val scheduler = new TaskScheduler(this) + val scheduler = new ClusterScheduler(this, isLocal = true) val backend = new LocalBackend(scheduler, threads.toInt) scheduler.initialize(backend) scheduler case SPARK_REGEX(sparkUrl) => - val scheduler = new TaskScheduler(this) + val scheduler = new 
ClusterScheduler(this) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) scheduler.initialize(backend) scheduler case SIMR_REGEX(simrUrl) => - val scheduler = new TaskScheduler(this) + val scheduler = new ClusterScheduler(this) val backend = new SimrSchedulerBackend(scheduler, this, simrUrl) scheduler.initialize(backend) scheduler @@ -192,7 +192,7 @@ class SparkContext( memoryPerSlaveInt, SparkContext.executorMemoryRequested)) } - val scheduler = new TaskScheduler(this) + val scheduler = new ClusterScheduler(this, isLocal = true) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val masterUrls = localCluster.start() @@ -207,7 +207,7 @@ class SparkContext( val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[TaskScheduler] + cons.newInstance(this).asInstanceOf[ClusterScheduler] } catch { // TODO: Enumerate the exact reasons why it can fail // But irrespective of it, it means we cannot proceed ! @@ -221,7 +221,7 @@ class SparkContext( case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() - val scheduler = new TaskScheduler(this) + val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala new file mode 100644 index 0000000000..c7d1295215 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicLong +import java.util.{TimerTask, Timer} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import org.apache.spark._ +import org.apache.spark.TaskState.TaskState +import org.apache.spark.scheduler.SchedulingMode.SchedulingMode + +/** + * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. + * It can also work with a local setup by using a LocalBackend and setting isLocal to true. + * It handles common logic, like determining a scheduling order across jobs, waking up to launch + * speculative tasks, etc. + * + * Clients should first call initialize() and start(), then submit task sets through the + * runTasks method. 
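The class comment above spells out the client protocol: construct a ClusterScheduler, hand it a SchedulerBackend through initialize(), call start(), and then submit work as TaskSets. A condensed sketch of that wiring, lifted from the "local" branch of the master pattern match in SparkContext earlier in this commit (only the helper object and method names are invented; since both classes are private[spark], this only compiles inside the Spark source tree):

    package org.apache.spark.scheduler

    import org.apache.spark.SparkContext
    import org.apache.spark.scheduler.local.LocalBackend

    object LocalSchedulerWiringSketch {
      // Mirrors the "local" case of the SparkContext pattern match shown above.
      def makeLocalScheduler(sc: SparkContext): ClusterScheduler = {
        val scheduler = new ClusterScheduler(sc, isLocal = true)
        val backend = new LocalBackend(scheduler, 1) // a single worker thread
        scheduler.initialize(backend)
        scheduler // the caller then invokes scheduler.start() and hands it TaskSets via submitTasks()
      }
    }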
+ * + * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple + * threads, so it needs locks in public API methods to maintain its state. In addition, some + * SchedulerBackends sycnchronize on themselves when they want to send events here, and then + * acquire a lock on us, so we need to make sure that we don't try to lock the backend while + * we are holding a lock on ourselves. + */ +private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = false) + extends TaskScheduler with Logging { + + // How often to check for speculative tasks + val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong + + // Threshold above which we warn user initial TaskSet may be starved + val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong + + // TaskSetManagers are not thread safe, so any access to one should be synchronized + // on this class. + val activeTaskSets = new HashMap[String, TaskSetManager] + + val MAX_TASK_FAILURES = { + if (isLocal) { + // No sense in retrying if all tasks run locally! + 0 + } else { + System.getProperty("spark.task.maxFailures", "4").toInt + } + } + + val taskIdToTaskSetId = new HashMap[Long, String] + val taskIdToExecutorId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + @volatile private var hasReceivedTask = false + @volatile private var hasLaunchedTask = false + private val starvationTimer = new Timer(true) + + // Incrementing task IDs + val nextTaskId = new AtomicLong(0) + + // Which executor IDs we have executors on + val activeExecutorIds = new HashSet[String] + + // The set of executors we have on each host; this is used to compute hostsAlive, which + // in turn is used to decide when we can attain data locality on a given host + private val executorsByHost = new HashMap[String, HashSet[String]] + + private val executorIdToHost = new HashMap[String, String] + + // Listener object to pass upcalls into + var dagScheduler: DAGScheduler = null + + var backend: SchedulerBackend = null + + val mapOutputTracker = SparkEnv.get.mapOutputTracker + + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + // default scheduler is FIFO + val schedulingMode: SchedulingMode = SchedulingMode.withName( + System.getProperty("spark.scheduler.mode", "FIFO")) + + // This is a var so that we can reset it for testing purposes. 
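Because taskResultGetter is declared as a var immediately below, a test can swap in an instrumented getter, which is exactly what TaskResultGetterSuite does with its ResultDeletingTaskResultGetter. The sketch below shows that pattern with an invented counting getter; the constructor and enqueueSuccessfulTask signatures are assumed from how TaskResultGetter is constructed and called in this file rather than quoted from it.

    package org.apache.spark.scheduler

    import java.nio.ByteBuffer
    import java.util.concurrent.atomic.AtomicInteger

    import org.apache.spark.SparkEnv

    // Hypothetical getter that counts successful results before delegating to the real logic.
    class CountingTaskResultGetter(env: SparkEnv, scheduler: ClusterScheduler)
      extends TaskResultGetter(env, scheduler) {
      val successfulResults = new AtomicInteger(0)

      override def enqueueSuccessfulTask(
          taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
        successfulResults.incrementAndGet()
        super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
      }
    }

    // In a test: scheduler.taskResultGetter = new CountingTaskResultGetter(sc.env, scheduler)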
+ private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) + + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler + } + + def initialize(context: SchedulerBackend) { + backend = context + // temporarily set rootPool name to empty + rootPool = new Pool("", schedulingMode, 0, 0) + schedulableBuilder = { + schedulingMode match { + case SchedulingMode.FIFO => + new FIFOSchedulableBuilder(rootPool) + case SchedulingMode.FAIR => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() + } + + def newTaskId(): Long = nextTaskId.getAndIncrement() + + override def start() { + backend.start() + + if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) { + new Thread("TaskScheduler speculation check") { + setDaemon(true) + + override def run() { + logInfo("Starting speculative execution thread") + while (true) { + try { + Thread.sleep(SPECULATION_INTERVAL) + } catch { + case e: InterruptedException => {} + } + checkSpeculatableTasks() + } + } + }.start() + } + } + + override def submitTasks(taskSet: TaskSet) { + val tasks = taskSet.tasks + logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") + this.synchronized { + val manager = new TaskSetManager(this, taskSet, MAX_TASK_FAILURES) + activeTaskSets(taskSet.id) = manager + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + + if (!isLocal && !hasReceivedTask) { + starvationTimer.scheduleAtFixedRate(new TimerTask() { + override def run() { + if (!hasLaunchedTask) { + logWarning("Initial job has not accepted any resources; " + + "check your cluster UI to ensure that workers are registered " + + "and have sufficient memory") + } else { + this.cancel() + } + } + }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + } + hasReceivedTask = true + } + backend.reviveOffers() + } + + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } + tsm.error("Stage %d was cancelled".format(stageId)) + } + } + + def taskSetFinished(manager: TaskSetManager): Unit = synchronized { + // Check to see if the given task set has been removed. This is possible in the case of + // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has + // more than one running tasks). + if (activeTaskSets.contains(manager.taskSet.id)) { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds.remove(manager.taskSet.id) + } + } + + /** + * Called by cluster manager to offer resources on slaves. We respond by asking our active task + * sets for tasks in order of priority. 
We fill each node with tasks in a round-robin manner so + * that tasks are balanced across the cluster. + */ + def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { + SparkEnv.set(sc.env) + + // Mark each slave as alive and remember its hostname + for (o <- offers) { + executorIdToHost(o.executorId) = o.host + if (!executorsByHost.contains(o.host)) { + executorsByHost(o.host) = new HashSet[String]() + executorGained(o.executorId, o.host) + } + } + + // Build a list of tasks to assign to each worker + val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = offers.map(o => o.cores).toArray + val sortedTaskSets = rootPool.getSortedTaskSetQueue() + for (taskSet <- sortedTaskSets) { + logDebug("parentName: %s, name: %s, runningTasks: %s".format( + taskSet.parent.name, taskSet.name, taskSet.runningTasks)) + } + + // Take each TaskSet in our scheduling order, and then offer it each node in increasing order + // of locality levels so that it gets a chance to launch local tasks on all of them. + var launchedTask = false + for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { + do { + launchedTask = false + for (i <- 0 until offers.size) { + val execId = offers(i).executorId + val host = offers(i).host + for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskSetTaskIds(taskSet.taskSet.id) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + launchedTask = true + } + } + } while (launchedTask) + } + + if (tasks.size > 0) { + hasLaunchedTask = true + } + return tasks + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + var failedExecutor: Option[String] = None + var taskFailed = false + synchronized { + try { + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) + } + } + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => + if (TaskState.isFinished(state)) { + taskIdToTaskSetId.remove(tid) + if (taskSetTaskIds.contains(taskSetId)) { + taskSetTaskIds(taskSetId) -= tid + } + taskIdToExecutorId.remove(tid) + } + if (state == TaskState.FAILED) { + taskFailed = true + } + activeTaskSets.get(taskSetId).foreach { taskSet => + if (state == TaskState.FINISHED) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + } + } + case None => + logInfo("Ignoring update from TID " + tid + " because its task set is gone") + } + } catch { + case e: Exception => logError("Exception in statusUpdate", e) + } + } + // Update the DAGScheduler without holding a lock on this, since that can deadlock + if (failedExecutor != None) { + dagScheduler.executorLost(failedExecutor.get) + backend.reviveOffers() + } + if (taskFailed) { + // Also revive offers if a task had failed for some reason other than host lost + backend.reviveOffers() + } + } + + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { + 
taskSetManager.handleTaskGettingResult(tid) + } + + def handleSuccessfulTask( + taskSetManager: TaskSetManager, + tid: Long, + taskResult: DirectTaskResult[_]) = synchronized { + taskSetManager.handleSuccessfulTask(tid, taskResult) + } + + def handleFailedTask( + taskSetManager: TaskSetManager, + tid: Long, + taskState: TaskState, + reason: Option[TaskEndReason]) = synchronized { + taskSetManager.handleFailedTask(tid, taskState, reason) + if (taskState == TaskState.FINISHED) { + // The task finished successfully but the result was lost, so we should revive offers. + backend.reviveOffers() + } + } + + def error(message: String) { + synchronized { + if (activeTaskSets.size > 0) { + // Have each task set throw a SparkException with the error + for ((taskSetId, manager) <- activeTaskSets) { + try { + manager.error(message) + } catch { + case e: Exception => logError("Exception in error callback", e) + } + } + } else { + // No task sets are active but we still got an error. Just exit since this + // must mean the error is during registration. + // It might be good to do something smarter here in the future. + logError("Exiting due to error from task scheduler: " + message) + System.exit(1) + } + } + } + + override def stop() { + if (backend != null) { + backend.stop() + } + if (taskResultGetter != null) { + taskResultGetter.stop() + } + + // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. + // TODO: Do something better ! + Thread.sleep(5000L) + } + + override def defaultParallelism() = backend.defaultParallelism() + + // Check for speculatable tasks in all our active jobs. + def checkSpeculatableTasks() { + var shouldRevive = false + synchronized { + shouldRevive = rootPool.checkSpeculatableTasks() + } + if (shouldRevive) { + backend.reviveOffers() + } + } + + // Check for pending tasks in all our active jobs. + def hasPendingTasks: Boolean = { + synchronized { + rootPool.hasPendingTasks() + } + } + + def executorLost(executorId: String, reason: ExecutorLossReason) { + var failedExecutor: Option[String] = None + + synchronized { + if (activeExecutorIds.contains(executorId)) { + val hostPort = executorIdToHost(executorId) + logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + removeExecutor(executorId) + failedExecutor = Some(executorId) + } else { + // We may get multiple executorLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. 
+ logError("Lost an executor " + executorId + " (already removed): " + reason) + } + } + // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock + if (failedExecutor != None) { + dagScheduler.executorLost(failedExecutor.get) + backend.reviveOffers() + } + } + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + rootPool.executorLost(executorId, host) + } + + def executorGained(execId: String, host: String) { + dagScheduler.executorGained(execId, host) + } + + def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { + executorsByHost.get(host).map(_.toSet) + } + + def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { + executorsByHost.contains(host) + } + + def isExecutorAlive(execId: String): Boolean = synchronized { + activeExecutorIds.contains(execId) + } + + // By default, rack is unknown + def getRackForHost(value: String): Option[String] = None +} + + +private[spark] object ClusterScheduler { + /** + * Used to balance containers across hosts. + * + * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of + * resource offers representing the order in which the offers should be used. The resource + * offers are ordered such that we'll allocate one container on each host before allocating a + * second container on any host, and so on, in order to reduce the damage if a host fails. + * + * For example, given , , , returns + * [o1, o5, o4, 02, o6, o3] + */ + def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { + val _keyList = new ArrayBuffer[K](map.size) + _keyList ++= map.keys + + // order keyList based on population of value in map + val keyList = _keyList.sortWith( + (left, right) => map(left).size > map(right).size + ) + + val retval = new ArrayBuffer[T](keyList.size * 2) + var index = 0 + var found = true + + while (found) { + found = false + for (key <- keyList) { + val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) + assert(containerList != null) + // Get the index'th entry for this host - if present + if (index < containerList.size){ + retval += containerList.apply(index) + found = true + } + } + index += 1 + } + + retval.toList + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 5408fa7353..a77ff35323 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.Utils /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. 
*/ -private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskScheduler) +private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) extends Logging { private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala deleted file mode 100644 index b4ec695ece..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ /dev/null @@ -1,492 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong -import java.util.{TimerTask, Timer} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import org.apache.spark._ -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode - -/** - * Schedules tasks for a single SparkContext. Receives a set of tasks from the DAGScheduler for - * each stage, and is responsible for sending tasks to executors, running them, retrying if there - * are failures, and mitigating stragglers. Returns events to the DAGScheduler. - * - * Clients should first call initialize() and start(), then submit task sets through the - * runTasks method. - * - * This class can work with multiple types of clusters by acting through a SchedulerBackend. - * It can also work with a local setup by using a LocalBackend and setting isLocal to true. - * It handles common logic, like determining a scheduling order across jobs, waking up to launch - * speculative tasks, etc. - * - * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple - * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends sycnchronize on themselves when they want to send events here, and then - * acquire a lock on us, so we need to make sure that we don't try to lock the backend while - * we are holding a lock on ourselves. - */ -private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = false) extends Logging { - // How often to check for speculative tasks - val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong - - // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong - - // TaskSetManagers are not thread safe, so any access to one should be synchronized - // on this class. 
- val activeTaskSets = new HashMap[String, TaskSetManager] - - val taskIdToTaskSetId = new HashMap[Long, String] - val taskIdToExecutorId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - @volatile private var hasReceivedTask = false - @volatile private var hasLaunchedTask = false - private val starvationTimer = new Timer(true) - - // Incrementing task IDs - val nextTaskId = new AtomicLong(0) - - // Which executor IDs we have executors on - val activeExecutorIds = new HashSet[String] - - // The set of executors we have on each host; this is used to compute hostsAlive, which - // in turn is used to decide when we can attain data locality on a given host - private val executorsByHost = new HashMap[String, HashSet[String]] - - private val executorIdToHost = new HashMap[String, String] - - // Listener object to pass upcalls into - var dagScheduler: DAGScheduler = null - - var backend: SchedulerBackend = null - - val mapOutputTracker = SparkEnv.get.mapOutputTracker - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - // default scheduler is FIFO - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.scheduler.mode", "FIFO")) - - // This is a var so that we can reset it for testing purposes. - private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) - - def setDAGScheduler(dagScheduler: DAGScheduler) { - this.dagScheduler = dagScheduler - } - - def initialize(context: SchedulerBackend) { - backend = context - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - } - - def newTaskId(): Long = nextTaskId.getAndIncrement() - - def start() { - backend.start() - - if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) { - new Thread("TaskScheduler speculation check") { - setDaemon(true) - - override def run() { - logInfo("Starting speculative execution thread") - while (true) { - try { - Thread.sleep(SPECULATION_INTERVAL) - } catch { - case e: InterruptedException => {} - } - checkSpeculatableTasks() - } - } - }.start() - } - } - - def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") - this.synchronized { - val manager = new TaskSetManager(this, taskSet) - activeTaskSets(taskSet.id) = manager - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - - if (!isLocal && !hasReceivedTask) { - starvationTimer.scheduleAtFixedRate(new TimerTask() { - override def run() { - if (!hasLaunchedTask) { - logWarning("Initial job has not accepted any resources; " + - "check your cluster UI to ensure that workers are registered " + - "and have sufficient memory") - } else { - this.cancel() - } - } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) - } - hasReceivedTask = true - } - backend.reviveOffers() - } - - def cancelTasks(stageId: Int): Unit = synchronized { - logInfo("Cancelling stage " + stageId) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. 
- // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - val taskIds = taskSetTaskIds(tsm.taskSet.id) - if (taskIds.size > 0) { - taskIds.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId) - } - } - tsm.error("Stage %d was cancelled".format(stageId)) - } - } - - def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - // Check to see if the given task set has been removed. This is possible in the case of - // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has - // more than one running tasks). - if (activeTaskSets.contains(manager.taskSet.id)) { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } - } - - /** - * Called by cluster manager to offer resources on slaves. We respond by asking our active task - * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so - * that tasks are balanced across the cluster. - */ - def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - - // Mark each slave as alive and remember its hostname - for (o <- offers) { - executorIdToHost(o.executorId) = o.host - if (!executorsByHost.contains(o.host)) { - executorsByHost(o.host) = new HashSet[String]() - executorGained(o.executorId, o.host) - } - } - - // Build a list of tasks to assign to each worker - val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue() - for (taskSet <- sortedTaskSets) { - logDebug("parentName: %s, name: %s, runningTasks: %s".format( - taskSet.parent.name, taskSet.name, taskSet.runningTasks)) - } - - // Take each TaskSet in our scheduling order, and then offer it each node in increasing order - // of locality levels so that it gets a chance to launch local tasks on all of them. 
- var launchedTask = false - for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { - do { - launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).host - for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskSetTaskIds(taskSet.taskSet.id) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - launchedTask = true - } - } - } while (launchedTask) - } - - if (tasks.size > 0) { - hasLaunchedTask = true - } - return tasks - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - var failedExecutor: Option[String] = None - var taskFailed = false - synchronized { - try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) - failedExecutor = Some(execId) - } - } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => - if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } - taskIdToExecutorId.remove(tid) - } - if (state == TaskState.FAILED) { - taskFailed = true - } - activeTaskSets.get(taskSetId).foreach { taskSet => - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) - } - } - case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") - } - } catch { - case e: Exception => logError("Exception in statusUpdate", e) - } - } - // Update the DAGScheduler without holding a lock on this, since that can deadlock - if (failedExecutor != None) { - dagScheduler.executorLost(failedExecutor.get) - backend.reviveOffers() - } - if (taskFailed) { - // Also revive offers if a task had failed for some reason other than host lost - backend.reviveOffers() - } - } - - def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { - taskSetManager.handleTaskGettingResult(tid) - } - - def handleSuccessfulTask( - taskSetManager: TaskSetManager, - tid: Long, - taskResult: DirectTaskResult[_]) = synchronized { - taskSetManager.handleSuccessfulTask(tid, taskResult) - } - - def handleFailedTask( - taskSetManager: TaskSetManager, - tid: Long, - taskState: TaskState, - reason: Option[TaskEndReason]) = synchronized { - taskSetManager.handleFailedTask(tid, taskState, reason) - if (taskState == TaskState.FINISHED) { - // The task finished successfully but the result was lost, so we should revive offers. - backend.reviveOffers() - } - } - - def error(message: String) { - synchronized { - if (activeTaskSets.size > 0) { - // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { - try { - manager.error(message) - } catch { - case e: Exception => logError("Exception in error callback", e) - } - } - } else { - // No task sets are active but we still got an error. Just exit since this - // must mean the error is during registration. 
- // It might be good to do something smarter here in the future. - logError("Exiting due to error from task scheduler: " + message) - System.exit(1) - } - } - } - - def stop() { - if (backend != null) { - backend.stop() - } - if (taskResultGetter != null) { - taskResultGetter.stop() - } - - // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. - // TODO: Do something better ! - Thread.sleep(5000L) - } - - def defaultParallelism() = backend.defaultParallelism() - - // Check for speculatable tasks in all our active jobs. - def checkSpeculatableTasks() { - var shouldRevive = false - synchronized { - shouldRevive = rootPool.checkSpeculatableTasks() - } - if (shouldRevive) { - backend.reviveOffers() - } - } - - // Check for pending tasks in all our active jobs. - def hasPendingTasks: Boolean = { - synchronized { - rootPool.hasPendingTasks() - } - } - - def executorLost(executorId: String, reason: ExecutorLossReason) { - var failedExecutor: Option[String] = None - - synchronized { - if (activeExecutorIds.contains(executorId)) { - val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) - failedExecutor = Some(executorId) - } else { - // We may get multiple executorLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) - } - } - // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock - if (failedExecutor != None) { - dagScheduler.executorLost(failedExecutor.get) - backend.reviveOffers() - } - } - - /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { - activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) - execs -= executorId - if (execs.isEmpty) { - executorsByHost -= host - } - executorIdToHost -= executorId - rootPool.executorLost(executorId, host) - } - - def executorGained(execId: String, host: String) { - dagScheduler.executorGained(execId, host) - } - - def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { - executorsByHost.get(host).map(_.toSet) - } - - def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { - executorsByHost.contains(host) - } - - def isExecutorAlive(execId: String): Boolean = synchronized { - activeExecutorIds.contains(execId) - } - - // By default, rack is unknown - def getRackForHost(value: String): Option[String] = None - - /** - * Invoked after the system has successfully been initialized. YARN uses this to bootstrap - * allocation of resources based on preferred locations, wait for slave registrations, etc. - */ - def postStartHook() { } -} - - -private[spark] object TaskScheduler { - /** - * Used to balance containers across hosts. - * - * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource - * offers are ordered such that we'll allocate one container on each host before allocating a - * second container on any host, and so on, in order to reduce the damage if a host fails. 
- * - * For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns - * [o1, o5, o4, 02, o6, o3] - */ - def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { - val _keyList = new ArrayBuffer[K](map.size) - _keyList ++= map.keys - - // order keyList based on population of value in map - val keyList = _keyList.sortWith( - (left, right) => map(left).size > map(right).size - ) - - val retval = new ArrayBuffer[T](keyList.size * 2) - var index = 0 - var found = true - - while (found) { - found = false - for (key <- keyList) { - val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) - assert(containerList != null) - // Get the index'th entry for this host - if present - if (index < containerList.size){ - retval += containerList.apply(index) - found = true - } - } - index += 1 - } - - retval.toList - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 90b6519027..8757d7fd2a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -40,19 +40,22 @@ import org.apache.spark.util.{SystemClock, Clock} * * THREADING: This class is designed to only be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. + * + * @param sched the ClusterScheduler associated with the TaskSetManager + * @param taskSet the TaskSet to manage scheduling for + * @param maxTaskFailures if any particular task fails more than this number of times, the entire + * task set will be aborted */ private[spark] class TaskSetManager( - sched: TaskScheduler, + sched: ClusterScheduler, val taskSet: TaskSet, + val maxTaskFailures: Int, clock: Clock = SystemClock) extends Schedulable with Logging { // CPUs to request per task val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt - // Quantile of tasks at which to start speculation val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble @@ -521,10 +524,10 @@ private[spark] class TaskSetManager( addPendingTask(index) if (state != TaskState.KILLED) { numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { + if (numFailures(index) > maxTaskFailures) { logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + taskSet.id, index, maxTaskFailures)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, maxTaskFailures)) } } } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index b8ac498527..f5548fc2da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -29,7 +29,7 @@ import akka.util.Duration import akka.util.duration._ import org.apache.spark.{SparkException, Logging, TaskState} -import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost,
TaskDescription, TaskScheduler, +import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, ClusterScheduler, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.Utils @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskScheduler, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index a589e7456f..40fdfcddb1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -21,10 +21,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.scheduler.ClusterScheduler private[spark] class SimrSchedulerBackend( - scheduler: TaskScheduler, + scheduler: ClusterScheduler, sc: SparkContext, driverFilePath: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 15c600a1ec..acf15dbc40 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -22,11 +22,11 @@ import scala.collection.mutable.HashMap import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.client.{Client, ClientListener} import org.apache.spark.deploy.{Command, ApplicationDescription} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskScheduler} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, ClusterScheduler} import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( - scheduler: TaskScheduler, + scheduler: ClusterScheduler, sc: SparkContext, masters: Array[String], appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 310da0027e..226ea46cc7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,7 +30,7 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{SparkException, Logging, SparkContext, TaskState} -import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.scheduler.ClusterScheduler import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend /** @@ -44,7 +44,7 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend * remove this. 
*/ private[spark] class CoarseMesosSchedulerBackend( - scheduler: TaskScheduler, + scheduler: ClusterScheduler, sc: SparkContext, master: String, appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index c0e99df0b6..3acad1bb46 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -31,7 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas import org.apache.spark.{Logging, SparkException, SparkContext, TaskState} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, - TaskDescription, TaskScheduler, WorkerOffer} + TaskDescription, ClusterScheduler, WorkerOffer} import org.apache.spark.util.Utils /** @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils * from multiple apps can run on different cores) and in time (a core can switch ownership). */ private[spark] class MesosSchedulerBackend( - scheduler: TaskScheduler, + scheduler: ClusterScheduler, sc: SparkContext, master: String, appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 96c3a03602..3e9d31cd5e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -24,16 +24,17 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, WorkerOffer} +import org.apache.spark.scheduler.{SchedulerBackend, ClusterScheduler, WorkerOffer} /** - * LocalBackend sits behind a TaskScheduler and handles launching tasks on a single Executor - * (created by the LocalBackend) running locally. + * LocalBackend is used when running a local version of Spark where the executor, backend, and + * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks + * on a single Executor (created by the LocalBackend) running locally. * * THREADING: Because methods can be called both from the Executor and the TaskScheduler, and * because the Executor class is not thread safe, all methods are synchronized. */ -private[spark] class LocalBackend(scheduler: TaskScheduler, private val totalCores: Int) +private[spark] class LocalBackend(scheduler: ClusterScheduler, private val totalCores: Int) extends SchedulerBackend with ExecutorBackend { private var freeCores = totalCores diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala new file mode 100644 index 0000000000..96adcf7198 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import scala.collection.mutable.ArrayBuffer + +import java.util.Properties + +class FakeTaskSetManager( + initPriority: Int, + initStageId: Int, + initNumTasks: Int, + taskScheduler: ClusterScheduler, + taskSet: TaskSet) + extends TaskSetManager(taskScheduler, taskSet, 1) { + + parent = null + weight = 1 + minShare = 2 + runningTasks = 0 + priority = initPriority + stageId = initStageId + name = "TaskSet_"+stageId + override val numTasks = initNumTasks + tasksSuccessful = 0 + + def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def addSchedulable(schedulable: Schedulable) { + } + + override def removeSchedulable(schedulable: Schedulable) { + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def executorLost(executorId: String, host: String): Unit = { + } + + override def resourceOffer( + execId: String, + host: String, + availableCpus: Int, + maxLocality: TaskLocality.TaskLocality) + : Option[TaskDescription] = + { + if (tasksSuccessful + runningTasks < numTasks) { + increaseRunningTasks(1) + return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) + } + return None + } + + override def checkSpeculatableTasks(): Boolean = { + return true + } + + def taskFinished() { + decreaseRunningTasks(1) + tasksSuccessful +=1 + if (tasksSuccessful == numTasks) { + parent.removeSchedulable(this) + } + } + + def abort() { + decreaseRunningTasks(runningTasks) + parent.removeSchedulable(this) + } +} + +class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { + + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { + new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) + } + + def resourceOffer(rootPool: Pool): Int = { + val taskSetQueue = rootPool.getSortedTaskSetQueue() + /* Just for Test*/ + for (manager <- taskSetQueue) { + logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format( + manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) + } + for (taskSet <- taskSetQueue) { + taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match { + case Some(task) => + return taskSet.stageId + case None => {} + } + } + -1 + } + + def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { + assert(resourceOffer(rootPool) === expectedTaskSetId) + } + + test("FIFO Scheduler Test") { + sc = new SparkContext("local", "TaskSchedulerSuite") + val taskScheduler = new ClusterScheduler(sc) + var tasks = 
ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet) + val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + schedulableBuilder.addTaskSetManager(taskSetManager1, null) + schedulableBuilder.addTaskSetManager(taskSetManager2, null) + + checkTaskSetId(rootPool, 0) + resourceOffer(rootPool) + checkTaskSetId(rootPool, 1) + resourceOffer(rootPool) + taskSetManager1.abort() + checkTaskSetId(rootPool, 2) + } + + test("Fair Scheduler Test") { + sc = new SparkContext("local", "TaskSchedulerSuite") + val taskScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.scheduler.allocation.file", xmlPath) + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName("default") != null) + assert(rootPool.getSchedulableByName("1") != null) + assert(rootPool.getSchedulableByName("2") != null) + assert(rootPool.getSchedulableByName("3") != null) + assert(rootPool.getSchedulableByName("1").minShare === 2) + assert(rootPool.getSchedulableByName("1").weight === 1) + assert(rootPool.getSchedulableByName("2").minShare === 3) + assert(rootPool.getSchedulableByName("2").weight === 1) + assert(rootPool.getSchedulableByName("3").minShare === 0) + assert(rootPool.getSchedulableByName("3").weight === 1) + + val properties1 = new Properties() + properties1.setProperty("spark.scheduler.pool","1") + val properties2 = new Properties() + properties2.setProperty("spark.scheduler.pool","2") + + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) + + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) + schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 1) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 4) + + taskSetManager12.taskFinished() + assert(rootPool.getSchedulableByName("1").runningTasks === 3) + taskSetManager24.abort() + assert(rootPool.getSchedulableByName("2").runningTasks === 2) + } + + test("Nested Pool Test") { + sc = new SparkContext("local", "TaskSchedulerSuite") + val taskScheduler = new 
ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) + val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) + rootPool.addSchedulable(pool0) + rootPool.addSchedulable(pool1) + + val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) + val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) + pool0.addSchedulable(pool00) + pool0.addSchedulable(pool01) + + val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) + val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) + pool1.addSchedulable(pool10) + pool1.addSchedulable(pool11) + + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet) + pool00.addSchedulable(taskSetManager000) + pool00.addSchedulable(taskSetManager001) + + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet) + pool01.addSchedulable(taskSetManager010) + pool01.addSchedulable(taskSetManager011) + + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet) + pool10.addSchedulable(taskSetManager100) + pool10.addSchedulable(taskSetManager101) + + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet) + pool11.addSchedulable(taskSetManager110) + pool11.addSchedulable(taskSetManager111) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 6) + checkTaskSetId(rootPool, 2) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5b5a2178f3..24689a7093 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} * TaskScheduler that records the task sets that the DAGScheduler requested executed. */ class TaskSetRecordingTaskScheduler(sc: SparkContext, - mapOutputTrackerMaster: MapOutputTrackerMaster) extends TaskScheduler(sc) { + mapOutputTrackerMaster: MapOutputTrackerMaster) extends ClusterScheduler(sc) { /** Set of TaskSets the DAGScheduler has requested executed. */ val taskSets = scala.collection.mutable.Buffer[TaskSet]() override def start() = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 30e6bc5721..2ac2d7a36a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.TaskResultBlockId * Used to test the case where a BlockManager evicts the task result (or dies) before the * TaskResult is retrieved. 
*/ -class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskScheduler) +class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) extends TaskResultGetter(sparkEnv, scheduler) { var removedResult = false @@ -91,8 +91,8 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA test("task retried if result missing from block manager") { // If this test hangs, it's probably because no resource offers were made after the task // failed. - val scheduler: TaskScheduler = sc.taskScheduler match { - case clusterScheduler: TaskScheduler => + val scheduler: ClusterScheduler = sc.taskScheduler match { + case clusterScheduler: ClusterScheduler => clusterScheduler case _ => assert(false, "Expect local cluster to use TaskScheduler") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala deleted file mode 100644 index bfbffdf261..0000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import org.apache.spark._ -import scala.collection.mutable.ArrayBuffer - -import java.util.Properties - -class FakeTaskSetManager( - initPriority: Int, - initStageId: Int, - initNumTasks: Int, - taskScheduler: TaskScheduler, - taskSet: TaskSet) - extends TaskSetManager(taskScheduler, taskSet) { - - parent = null - weight = 1 - minShare = 2 - runningTasks = 0 - priority = initPriority - stageId = initStageId - name = "TaskSet_"+stageId - override val numTasks = initNumTasks - tasksSuccessful = 0 - - def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable) { - } - - override def removeSchedulable(schedulable: Schedulable) { - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksSuccessful + runningTasks < numTasks) { - increaseRunningTasks(1) - return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) - } - return None - } - - override def checkSpeculatableTasks(): Boolean = { - return true - } - - def taskFinished() { - decreaseRunningTasks(1) - tasksSuccessful +=1 - if (tasksSuccessful == numTasks) { - parent.removeSchedulable(this) - } - } - - def abort() { - decreaseRunningTasks(runningTasks) - parent.removeSchedulable(this) - } -} - -class TaskSchedulerSuite extends FunSuite with LocalSparkContext with Logging { - - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskScheduler, taskSet: TaskSet): FakeTaskSetManager = { - new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) - } - - def resourceOffer(rootPool: Pool): Int = { - val taskSetQueue = rootPool.getSortedTaskSetQueue() - /* Just for Test*/ - for (manager <- taskSetQueue) { - logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format( - manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) - } - for (taskSet <- taskSetQueue) { - taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match { - case Some(task) => - return taskSet.stageId - case None => {} - } - } - -1 - } - - def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { - assert(resourceOffer(rootPool) === expectedTaskSetId) - } - - test("FIFO Scheduler Test") { - sc = new SparkContext("local", "TaskSchedulerSuite") - val taskScheduler = new TaskScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) - val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() - - val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet) - val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet) - val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager0, null) - 
schedulableBuilder.addTaskSetManager(taskSetManager1, null) - schedulableBuilder.addTaskSetManager(taskSetManager2, null) - - checkTaskSetId(rootPool, 0) - resourceOffer(rootPool) - checkTaskSetId(rootPool, 1) - resourceOffer(rootPool) - taskSetManager1.abort() - checkTaskSetId(rootPool, 2) - } - - test("Fair Scheduler Test") { - sc = new SparkContext("local", "TaskSchedulerSuite") - val taskScheduler = new TaskScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val schedulableBuilder = new FairSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() - - assert(rootPool.getSchedulableByName("default") != null) - assert(rootPool.getSchedulableByName("1") != null) - assert(rootPool.getSchedulableByName("2") != null) - assert(rootPool.getSchedulableByName("3") != null) - assert(rootPool.getSchedulableByName("1").minShare === 2) - assert(rootPool.getSchedulableByName("1").weight === 1) - assert(rootPool.getSchedulableByName("2").minShare === 3) - assert(rootPool.getSchedulableByName("2").weight === 1) - assert(rootPool.getSchedulableByName("3").minShare === 0) - assert(rootPool.getSchedulableByName("3").weight === 1) - - val properties1 = new Properties() - properties1.setProperty("spark.scheduler.pool","1") - val properties2 = new Properties() - properties2.setProperty("spark.scheduler.pool","2") - - val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet) - val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet) - val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) - schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) - schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) - - val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet) - val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) - schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) - - checkTaskSetId(rootPool, 0) - checkTaskSetId(rootPool, 3) - checkTaskSetId(rootPool, 3) - checkTaskSetId(rootPool, 1) - checkTaskSetId(rootPool, 4) - checkTaskSetId(rootPool, 2) - checkTaskSetId(rootPool, 2) - checkTaskSetId(rootPool, 4) - - taskSetManager12.taskFinished() - assert(rootPool.getSchedulableByName("1").runningTasks === 3) - taskSetManager24.abort() - assert(rootPool.getSchedulableByName("2").runningTasks === 2) - } - - test("Nested Pool Test") { - sc = new SparkContext("local", "TaskSchedulerSuite") - val taskScheduler = new TaskScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) - val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) - rootPool.addSchedulable(pool0) - rootPool.addSchedulable(pool1) - - val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) - val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) - pool0.addSchedulable(pool00) - pool0.addSchedulable(pool01) - - val pool10 = new Pool("10", 
SchedulingMode.FAIR, 2, 2) - val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) - pool1.addSchedulable(pool10) - pool1.addSchedulable(pool11) - - val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet) - val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet) - pool00.addSchedulable(taskSetManager000) - pool00.addSchedulable(taskSetManager001) - - val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet) - val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet) - pool01.addSchedulable(taskSetManager010) - pool01.addSchedulable(taskSetManager011) - - val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet) - val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet) - pool10.addSchedulable(taskSetManager100) - pool10.addSchedulable(taskSetManager101) - - val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet) - val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet) - pool11.addSchedulable(taskSetManager110) - pool11.addSchedulable(taskSetManager111) - - checkTaskSetId(rootPool, 0) - checkTaskSetId(rootPool, 4) - checkTaskSetId(rootPool, 6) - checkTaskSetId(rootPool, 2) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index fe3ea7b594..592bb11364 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -58,7 +58,7 @@ class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(ta * to work, and these are required for locality in TaskSetManager. 
*/ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) - extends TaskScheduler(sc) + extends ClusterScheduler(sc) { val startedTasks = new ArrayBuffer[Long] val endedTasks = new mutable.HashMap[Long, TaskEndReason] @@ -82,12 +82,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + val MAX_TASK_FAILURES = 4 test("TaskSet with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) - val manager = new TaskSetManager(sched, taskSet) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // Offer a host with no CPUs assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) @@ -113,7 +114,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(3) - val manager = new TaskSetManager(sched, taskSet) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // First three offers should all find tasks for (i <- 0 until 3) { @@ -150,7 +151,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq() // Last task has no locality prefs ) val clock = new FakeClock - val manager = new TaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -196,7 +197,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host2")) ) val clock = new FakeClock - val manager = new TaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -233,7 +234,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host3")) ) val clock = new FakeClock - val manager = new TaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -261,7 +262,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) val clock = new FakeClock - val manager = new TaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -278,17 +279,17 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) val clock = new FakeClock - val manager = new TaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted // after the last failure. 
- (0 until manager.MAX_TASK_FAILURES).foreach { index => + (0 until MAX_TASK_FAILURES).foreach { index => val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) assert(offerResult != None, "Expect resource offer on iteration %s to return a task".format(index)) assert(offerResult.get.index === 0) manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost)) - if (index < manager.MAX_TASK_FAILURES) { + if (index < MAX_TASK_FAILURES) { assert(!sched.taskSetsFailed.contains(taskSet.id)) } else { assert(sched.taskSetsFailed.contains(taskSet.id)) -- cgit v1.2.3 From 46f9c6b858cf9737b7d46b22b75bfc847244331b Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 13 Nov 2013 15:46:41 -0800 Subject: Fixed naming issues and added back ability to specify max task failures. --- .../main/scala/org/apache/spark/SparkContext.scala | 17 +++- .../apache/spark/scheduler/ClusterScheduler.scala | 19 ++--- .../apache/spark/scheduler/SchedulerBackend.scala | 2 +- .../org/apache/spark/scheduler/TaskScheduler.scala | 56 +++++++++++++ .../apache/spark/scheduler/TaskSetManager.scala | 2 +- .../cluster/CoarseGrainedSchedulerBackend.scala | 2 +- .../cluster/mesos/MesosSchedulerBackend.scala | 2 +- .../test/scala/org/apache/spark/FailureSuite.scala | 20 ++--- .../spark/scheduler/ClusterSchedulerSuite.scala | 48 +++++------ .../apache/spark/scheduler/DAGSchedulerSuite.scala | 97 +++++++++++----------- .../spark/scheduler/TaskResultGetterSuite.scala | 13 +-- .../spark/scheduler/TaskSetManagerSuite.scala | 20 ++--- .../scheduler/cluster/YarnClusterScheduler.scala | 6 +- 13 files changed, 177 insertions(+), 127 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 10db2fa7e7..06bea0c535 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -156,6 +156,8 @@ class SparkContext( private[spark] var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters @@ -165,19 +167,28 @@ class SparkContext( // Regular expression for connection to Simr cluster val SIMR_REGEX = """simr://(.*)""".r + // When running locally, don't try to re-execute tasks on failure. 
+ val MAX_LOCAL_TASK_FAILURES = 0 + master match { case "local" => - val scheduler = new ClusterScheduler(this, isLocal = true) + val scheduler = new ClusterScheduler(this, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, 1) scheduler.initialize(backend) scheduler case LOCAL_N_REGEX(threads) => - val scheduler = new ClusterScheduler(this, isLocal = true) + val scheduler = new ClusterScheduler(this, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, threads.toInt) scheduler.initialize(backend) scheduler + case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + val scheduler = new ClusterScheduler(this, maxFailures.toInt, isLocal = true) + val backend = new LocalBackend(scheduler, threads.toInt) + scheduler.initialize(backend) + scheduler + case SPARK_REGEX(sparkUrl) => val scheduler = new ClusterScheduler(this) val masterUrls = sparkUrl.split(",").map("spark://" + _) @@ -200,7 +211,7 @@ class SparkContext( memoryPerSlaveInt, SparkContext.executorMemoryRequested)) } - val scheduler = new ClusterScheduler(this, isLocal = true) + val scheduler = new ClusterScheduler(this) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val masterUrls = localCluster.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala index c5d7ca0481..37d554715d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala @@ -46,8 +46,10 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ -private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = false) - extends TaskScheduler with Logging { +private[spark] class ClusterScheduler( + val sc: SparkContext, + val maxTaskFailures : Int = System.getProperty("spark.task.maxFailures", "4").toInt, + isLocal: Boolean = false) extends TaskScheduler with Logging { // How often to check for speculative tasks val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong @@ -59,15 +61,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f // on this class. val activeTaskSets = new HashMap[String, TaskSetManager] - val MAX_TASK_FAILURES = { - if (isLocal) { - // No sense in retrying if all tasks run locally! - 0 - } else { - System.getProperty("spark.task.maxFailures", "4").toInt - } - } - val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] @@ -142,7 +135,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet, MAX_TASK_FAILURES) + val manager = new TaskSetManager(this, taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) taskSetTaskIds(taskSet.id) = new HashSet[Long]() @@ -345,7 +338,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f // No task sets are active but we still got an error. 
Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. - logError("Exiting due to error from task scheduler: " + message) + logError("Exiting due to error from cluster scheduler: " + message) System.exit(1) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 1f0839a0e1..89aa098664 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext /** * A backend interface for scheduling systems that allows plugging in different ones under - * TaskScheduler. We assume a Mesos-like model where the application gets resource offers as + * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as * machines become available and can launch tasks on them. */ private[spark] trait SchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala new file mode 100644 index 0000000000..17b6d97e90 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.scheduler.SchedulingMode.SchedulingMode + +/** + * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler. + * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks + * for a single SparkContext. These schedulers get sets of tasks submitted to them from the + * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running + * them, retrying if there are failures, and mitigating stragglers. They return events to the + * DAGScheduler. + */ +private[spark] trait TaskScheduler { + + def rootPool: Pool + + def schedulingMode: SchedulingMode + + def start(): Unit + + // Invoked after system has successfully initialized (typically in spark context). + // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. + def postStartHook() { } + + // Disconnect from the cluster. + def stop(): Unit + + // Submit a sequence of tasks to run. + def submitTasks(taskSet: TaskSet): Unit + + // Cancel a stage. + def cancelTasks(stageId: Int) + + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. 
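// A sketch of what a concrete implementation of this trait can look like (hypothetical class
// name; assumes the Pool, TaskSet, DAGScheduler and SchedulingMode types available in this
// package). DAGSchedulerSuite later in this series uses essentially the same shape to record
// submitted TaskSets instead of running them:
//
//   class RecordingTaskScheduler extends TaskScheduler {
//     val taskSets = scala.collection.mutable.Buffer[TaskSet]()
//     override def rootPool: Pool = null
//     override def schedulingMode: SchedulingMode = SchedulingMode.NONE
//     override def start() {}
//     override def stop() {}
//     override def submitTasks(taskSet: TaskSet) { taskSets += taskSet }
//     override def cancelTasks(stageId: Int) {}
//     override def setDAGScheduler(dagScheduler: DAGScheduler) {}
//     override def defaultParallelism() = 2
//   }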
+ def setDAGScheduler(dagScheduler: DAGScheduler): Unit + + // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. + def defaultParallelism(): Int +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8757d7fd2a..bc35e53220 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.{SystemClock, Clock} /** - * Schedules the tasks within a single TaskSet in the TaskScheduler. This class keeps track of + * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of * each task, retries tasks if they fail (up to a limited number of times), and * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 3bb715e7d0..3af02b42b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -29,7 +29,7 @@ import akka.util.Duration import akka.util.duration._ import org.apache.spark.{SparkException, Logging, TaskState} -import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, ClusterScheduler, +import org.apache.spark.scheduler.{ClusterScheduler, SchedulerBackend, SlaveLost, TaskDescription, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 3acad1bb46..773b980c53 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -209,7 +209,7 @@ private[spark] class MesosSchedulerBackend( getResource(offer.getResourcesList, "cpus").toInt) } - // Call into the TaskScheduler + // Call into the ClusterScheduler val taskLists = scheduler.resourceOffers(offerableWorkers) // Build a list of Mesos tasks for each slave diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 2f7d6dff38..af448fcb37 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.FunSuite import SparkContext._ import org.apache.spark.util.NonSerializable @@ -37,20 +37,12 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAll { - - override def beforeAll { - System.setProperty("spark.task.maxFailures", "1") - } - - override def afterAll { - System.clearProperty("spark.task.maxFailures") - } +class FailureSuite extends FunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails 
once, and check // whether the job completes successfully and we ran 4 tasks in total. test("failure in a single-stage job") { - sc = new SparkContext("local[1]", "test") + sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3, 3).map { x => FailureSuiteState.synchronized { FailureSuiteState.tasksRun += 1 @@ -70,7 +62,7 @@ class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAl // Run a map-reduce job in which a reduce task deterministically fails once. test("failure in a two-stage job") { - sc = new SparkContext("local[1]", "test") + sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { case (k, v) => FailureSuiteState.synchronized { @@ -90,7 +82,7 @@ class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAl } test("failure because task results are not serializable") { - sc = new SparkContext("local[1]", "test") + sc = new SparkContext("local[1,1]", "test") val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) val thrown = intercept[SparkException] { @@ -103,7 +95,7 @@ class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAl } test("failure because task closure is not serializable") { - sc = new SparkContext("local[1]", "test") + sc = new SparkContext("local[1,1]", "test") val a = new NonSerializable // Non-serializable closure in the final result stage diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala index 96adcf7198..35a06c4875 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -29,9 +29,9 @@ class FakeTaskSetManager( initPriority: Int, initStageId: Int, initNumTasks: Int, - taskScheduler: ClusterScheduler, + clusterScheduler: ClusterScheduler, taskSet: TaskSet) - extends TaskSetManager(taskScheduler, taskSet, 1) { + extends TaskSetManager(clusterScheduler, taskSet, 0) { parent = null weight = 1 @@ -130,8 +130,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging } test("FIFO Scheduler Test") { - sc = new SparkContext("local", "TaskSchedulerSuite") - val taskScheduler = new ClusterScheduler(sc) + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -141,9 +141,9 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) schedulableBuilder.buildPools() - val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet) - val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet) - val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet) + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) + val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager0, null) schedulableBuilder.addTaskSetManager(taskSetManager1, null) schedulableBuilder.addTaskSetManager(taskSetManager2, null) @@ -157,8 +157,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging } test("Fair Scheduler 
Test") { - sc = new SparkContext("local", "TaskSchedulerSuite") - val taskScheduler = new ClusterScheduler(sc) + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -186,15 +186,15 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging val properties2 = new Properties() properties2.setProperty("spark.scheduler.pool","2") - val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet) - val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet) - val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet) + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) - val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet) - val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet) + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) @@ -214,8 +214,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging } test("Nested Pool Test") { - sc = new SparkContext("local", "TaskSchedulerSuite") - val taskScheduler = new ClusterScheduler(sc) + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -237,23 +237,23 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging pool1.addSchedulable(pool10) pool1.addSchedulable(pool11) - val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet) - val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet) + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) pool00.addSchedulable(taskSetManager000) pool00.addSchedulable(taskSetManager001) - val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet) - val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet) + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) pool01.addSchedulable(taskSetManager010) pool01.addSchedulable(taskSetManager011) - val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet) - val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet) + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) pool10.addSchedulable(taskSetManager100) pool10.addSchedulable(taskSetManager101) - val 
taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet) - val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet) + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) pool11.addSchedulable(taskSetManager110) pool11.addSchedulable(taskSetManager111) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 24689a7093..00f2fdd657 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -33,25 +33,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} -/** - * TaskScheduler that records the task sets that the DAGScheduler requested executed. - */ -class TaskSetRecordingTaskScheduler(sc: SparkContext, - mapOutputTrackerMaster: MapOutputTrackerMaster) extends ClusterScheduler(sc) { - /** Set of TaskSets the DAGScheduler has requested executed. */ - val taskSets = scala.collection.mutable.Buffer[TaskSet]() - override def start() = {} - override def stop() = {} - override def submitTasks(taskSet: TaskSet) = { - // normally done by TaskSetManager - taskSet.tasks.foreach(_.epoch = mapOutputTrackerMaster.getEpoch) - taskSets += taskSet - } - override def cancelTasks(stageId: Int) {} - override def setDAGScheduler(dagScheduler: DAGScheduler) = {} - override def defaultParallelism() = 2 -} - /** * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler * rather than spawning an event loop thread as happens in the real code. They use EasyMock @@ -65,7 +46,24 @@ class TaskSetRecordingTaskScheduler(sc: SparkContext, * and capturing the resulting TaskSets from the mock TaskScheduler. */ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - var taskScheduler: TaskSetRecordingTaskScheduler = null + + /** Set of TaskSets the DAGScheduler has requested executed. 
*/ + val taskSets = scala.collection.mutable.Buffer[TaskSet]() + val taskScheduler = new TaskScheduler() { + override def rootPool: Pool = null + override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def start() = {} + override def stop() = {} + override def submitTasks(taskSet: TaskSet) = { + // normally done by TaskSetManager + taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) + taskSets += taskSet + } + override def cancelTasks(stageId: Int) {} + override def setDAGScheduler(dagScheduler: DAGScheduler) = {} + override def defaultParallelism() = 2 + } + var mapOutputTracker: MapOutputTrackerMaster = null var scheduler: DAGScheduler = null @@ -98,11 +96,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont before { sc = new SparkContext("local", "DAGSchedulerSuite") - mapOutputTracker = new MapOutputTrackerMaster() - taskScheduler = new TaskSetRecordingTaskScheduler(sc, mapOutputTracker) - taskScheduler.taskSets.clear() + taskSets.clear() cacheLocations.clear() results.clear() + mapOutputTracker = new MapOutputTrackerMaster() scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) { override def runLocally(job: ActiveJob) { // don't bother with the thread while unit testing @@ -207,7 +204,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont test("run trivial job") { val rdd = makeRdd(1, Nil) submit(rdd, Array(0)) - complete(taskScheduler.taskSets(0), List((Success, 42))) + complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) } @@ -228,7 +225,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val baseRdd = makeRdd(1, Nil) val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) submit(finalRdd, Array(0)) - complete(taskScheduler.taskSets(0), Seq((Success, 42))) + complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -238,7 +235,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont cacheLocations(baseRdd.id -> 0) = Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) submit(finalRdd, Array(0)) - val taskSet = taskScheduler.taskSets(0) + val taskSet = taskSets(0) assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) @@ -246,7 +243,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) - failed(taskScheduler.taskSets(0), "some failure") + failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted: some failure") } @@ -256,12 +253,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(1, List(shuffleDep)) submit(reduceRdd, Array(0)) - complete(taskScheduler.taskSets(0), Seq( + complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) - complete(taskScheduler.taskSets(1), Seq((Success, 42))) + complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -271,11 +268,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) - 
complete(taskScheduler.taskSets(0), Seq( + complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) // the 2nd ResultTask failed - complete(taskScheduler.taskSets(1), Seq( + complete(taskSets(1), Seq( (Success, 42), (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) // this will get called @@ -283,10 +280,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskScheduler.taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) - complete(taskScheduler.taskSets(3), Seq((Success, 43))) + complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) } @@ -302,7 +299,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) val noAccum = Map[Long, Any]() - val taskSet = taskScheduler.taskSets(0) + val taskSet = taskSets(0) // should be ignored for being too old runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) // should work because it's a non-failed host @@ -314,7 +311,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) - complete(taskScheduler.taskSets(1), Seq((Success, 42), (Success, 43))) + complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) } @@ -329,14 +326,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont runEvent(ExecutorLost("exec-hostA")) // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks // rather than marking it is as failed and waiting. 
- complete(taskScheduler.taskSets(0), Seq( + complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task - complete(taskScheduler.taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) + complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) - complete(taskScheduler.taskSets(2), Seq((Success, 42))) + complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -348,23 +345,23 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val finalRdd = makeRdd(1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) // have the first stage complete normally - complete(taskScheduler.taskSets(0), Seq( + complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // have the second stage complete normally - complete(taskScheduler.taskSets(1), Seq( + complete(taskSets(1), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down - complete(taskScheduler.taskSets(2), Seq( + complete(taskSets(2), Seq( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again scheduler.resubmitFailedStages() - complete(taskScheduler.taskSets(3), Seq((Success, makeMapStatus("hostA", 2)))) - complete(taskScheduler.taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) - complete(taskScheduler.taskSets(5), Seq((Success, 42))) + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2)))) + complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -378,24 +375,24 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) // complete stage 2 - complete(taskScheduler.taskSets(0), Seq( + complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // complete stage 1 - complete(taskScheduler.taskSets(1), Seq( + complete(taskSets(1), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) // pretend stage 0 failed because hostA went down - complete(taskScheduler.taskSets(2), Seq( + complete(taskSets(2), Seq( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. 
scheduler.resubmitFailedStages() - assertLocations(taskScheduler.taskSets(3), Seq(Seq("hostD"))) + assertLocations(taskSets(3), Seq(Seq("hostD"))) // allow hostD to recover - complete(taskScheduler.taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) - complete(taskScheduler.taskSets(4), Seq((Success, 42))) + complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) + complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 2ac2d7a36a..b0d1902c67 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -64,20 +64,18 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA System.setProperty("spark.akka.frameSize", "1") } - before { - sc = new SparkContext("local", "test") - } - override def afterAll { System.clearProperty("spark.akka.frameSize") } test("handling results smaller than Akka frame size") { + sc = new SparkContext("local", "test") val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) assert(result === 2) } - test("handling results larger than Akka frame size") { + test("handling results larger than Akka frame size") { + sc = new SparkContext("local", "test") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) @@ -89,13 +87,16 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA } test("task retried if result missing from block manager") { + // Set the maximum number of task failures to > 0, so that the task set isn't aborted + // after the result is missing. + sc = new SparkContext("local[1,1]", "test") // If this test hangs, it's probably because no resource offers were made after the task // failed. 
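// Side note on the "local[N, maxRetries]" master strings used in this suite and in FailureSuite
// above (illustrative sketch; the app name and the flaky function are made up): "local[1,1]"
// runs one worker thread and tolerates one failure per task, so a task that deterministically
// fails once is retried and the job still succeeds, whereas plain "local" sets maxFailures to 0
// and aborts the task set on the first failure.
//
//   val sc = new SparkContext("local[1,1]", "retry-example")
//   sc.makeRDD(1 to 3, 3).map(failsOnceThenSucceeds).count()   // 3 tasks, 4 attempts in total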
val scheduler: ClusterScheduler = sc.taskScheduler match { case clusterScheduler: ClusterScheduler => clusterScheduler case _ => - assert(false, "Expect local cluster to use TaskScheduler") + assert(false, "Expect local cluster to use ClusterScheduler") throw new ClassCastException } scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 592bb11364..4bbb51532d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.executor.TaskMetrics import java.nio.ByteBuffer import org.apache.spark.util.{Utils, FakeClock} -class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(taskScheduler) { +class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) { taskScheduler.startedTasks += taskInfo.index } @@ -52,12 +52,12 @@ class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(ta } /** - * A mock TaskScheduler implementation that just remembers information about tasks started and + * A mock ClusterScheduler implementation that just remembers information about tasks started and * feedback received from the TaskSetManagers. Note that it's important to initialize this with * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost * to work, and these are required for locality in TaskSetManager. */ -class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) +class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) extends ClusterScheduler(sc) { val startedTasks = new ArrayBuffer[Long] @@ -86,7 +86,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("TaskSet with no preferences") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) @@ -112,7 +112,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("multiple offers with no preferences") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) @@ -143,7 +143,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("basic delay scheduling") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) val taskSet = createTaskSet(4, Seq(TaskLocation("host1", "exec1")), Seq(TaskLocation("host2", "exec2")), @@ -187,7 +187,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("delay scheduling with fallback") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), 
("exec3", "host3")) val taskSet = createTaskSet(5, Seq(TaskLocation("host1")), @@ -227,7 +227,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("delay scheduling with failed hosts") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) val taskSet = createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), @@ -259,7 +259,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("task result lost") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) @@ -276,7 +276,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("repeated failures lead to task set abortion") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index e873400680..4e988b8017 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -21,16 +21,16 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark._ import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} -import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.scheduler.ClusterScheduler import org.apache.spark.util.Utils /** * - * This is a simple extension to TaskScheduler - to ensure that appropriate initialization of + * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of * ApplicationMaster, etc. 
is done */ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) - extends TaskScheduler(sc) { + extends ClusterScheduler(sc) { logInfo("Created YarnClusterScheduler") -- cgit v1.2.3 From c64690d7252248df97bbe4b2bef8f540b977842d Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 14 Nov 2013 09:34:56 -0800 Subject: Changed local backend to use Akka actor --- .../spark/scheduler/local/LocalBackend.scala | 80 +++++++++++++++------- 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 3e9d31cd5e..d9b941d694 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -21,21 +21,26 @@ import java.nio.ByteBuffer import akka.actor.{Actor, ActorRef, Props} -import org.apache.spark.{SparkContext, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, ClusterScheduler, WorkerOffer} +private case class ReviveOffers() + +private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) + +private case class KillTask(taskId: Long) + /** - * LocalBackend is used when running a local version of Spark where the executor, backend, and - * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks - * on a single Executor (created by the LocalBackend) running locally. - * - * THREADING: Because methods can be called both from the Executor and the TaskScheduler, and - * because the Executor class is not thread safe, all methods are synchronized. + * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on + * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * and the ClusterScheduler. 
*/ -private[spark] class LocalBackend(scheduler: ClusterScheduler, private val totalCores: Int) - extends SchedulerBackend with ExecutorBackend { +private[spark] class LocalActor( + scheduler: ClusterScheduler, + executorBackend: LocalBackend, + private val totalCores: Int) extends Actor with Logging { private var freeCores = totalCores @@ -44,31 +49,60 @@ private[spark] class LocalBackend(scheduler: ClusterScheduler, private val total val executor = new Executor(localExecutorId, localExecutorHostname, Seq.empty, isLocal = true) - override def start() { - } + def receive = { + case ReviveOffers => + reviveOffers() - override def stop() { + case StatusUpdate(taskId, state, serializedData) => + scheduler.statusUpdate(taskId, state, serializedData) + if (TaskState.isFinished(state)) { + freeCores += 1 + reviveOffers() + } + + case KillTask(taskId) => + executor.killTask(taskId) } - override def reviveOffers() = synchronized { - val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + def reviveOffers() { + val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= 1 - executor.launchTask(this, task.taskId, task.serializedTask) + executor.launchTask(executorBackend, task.taskId, task.serializedTask) } } +} + +/** + * LocalBackend is used when running a local version of Spark where the executor, backend, and + * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks + * on a single Executor (created by the LocalBackend) running locally. + */ +private[spark] class LocalBackend(scheduler: ClusterScheduler, private val totalCores: Int) + extends SchedulerBackend with ExecutorBackend { + + var localActor: ActorRef = null + + override def start() { + localActor = SparkEnv.get.actorSystem.actorOf( + Props(new LocalActor(scheduler, this, totalCores)), + "LocalBackendActor") + } + + override def stop() { + } + + override def reviveOffers() { + localActor ! ReviveOffers + } override def defaultParallelism() = totalCores - override def killTask(taskId: Long, executorId: String) = synchronized { - executor.killTask(taskId) + override def killTask(taskId: Long, executorId: String) { + localActor ! KillTask(taskId) } - override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) = synchronized { - scheduler.statusUpdate(taskId, state, serializedData) - if (TaskState.isFinished(state)) { - freeCores += 1 - reviveOffers() - } + override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { + localActor ! StatusUpdate(taskId, state, serializedData) } } -- cgit v1.2.3 From 2b807e4f2f853a9b1e8cba5147d182e7b05022bc Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 14 Nov 2013 13:33:11 -0800 Subject: Fix bug where scheduler could hang after task failure. When a task fails, we need to call reviveOffers() so that the task can be rescheduled on a different machine. In the current code, the state in ClusterTaskSetManager indicating which tasks are pending may be updated after revive offers is called (there's a race condition here), so when revive offers is called, the task set manager does not yet realize that there are failed tasks that need to be relaunched. 
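The hunk below implements this by reviving offers from handleFailedTask(), after the
TaskSetManager has recorded the failure. A schematic restatement of the resulting method, with
the ordering annotated (the enclosing ClusterScheduler class and its backend field are elided):

    def handleFailedTask(
        taskSetManager: TaskSetManager,
        tid: Long,
        taskState: TaskState,
        reason: Option[TaskEndReason]) = synchronized {
      // 1. Update the task set manager first, so the failed task is marked as pending again.
      taskSetManager.handleFailedTask(tid, taskState, reason)
      if (taskState != TaskState.KILLED) {
        // 2. Only then ask the backend for new offers; the revived offer now sees the task
        //    that needs to be re-run.
        backend.reviveOffers()
      }
    }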
--- .../scala/org/apache/spark/scheduler/ClusterScheduler.scala | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala index 37d554715d..2e4ba53d9b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala @@ -250,7 +250,6 @@ private[spark] class ClusterScheduler( def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { var failedExecutor: Option[String] = None - var taskFailed = false synchronized { try { if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { @@ -270,9 +269,6 @@ private[spark] class ClusterScheduler( } taskIdToExecutorId.remove(tid) } - if (state == TaskState.FAILED) { - taskFailed = true - } activeTaskSets.get(taskSetId).foreach { taskSet => if (state == TaskState.FINISHED) { taskSet.removeRunningTask(tid) @@ -294,10 +290,6 @@ private[spark] class ClusterScheduler( dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } - if (taskFailed) { - // Also revive offers if a task had failed for some reason other than host lost - backend.reviveOffers() - } } def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { @@ -317,8 +309,9 @@ private[spark] class ClusterScheduler( taskState: TaskState, reason: Option[TaskEndReason]) = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (taskState == TaskState.FINISHED) { - // The task finished successfully but the result was lost, so we should revive offers. + if (taskState != TaskState.KILLED) { + // Need to revive offers again now that the task set manager state has been updated to + // reflect failed tasks that need to be re-run. backend.reviveOffers() } } -- cgit v1.2.3 From 52144caaa70363ffcc63e1f52db32eb1654c1213 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 14 Nov 2013 14:56:53 -0800 Subject: Don't retry tasks if result wasn't serializable --- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index bc35e53220..e3929e61ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, Sp Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{SystemClock, Clock} +import java.io.NotSerializableException /** @@ -488,7 +489,16 @@ private[spark] class TaskSetManager( return case ef: ExceptionFailure => - sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + sched.dagScheduler.taskEnded( + tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + if (ef.className == classOf[NotSerializableException].getName()) { + // If the task result wasn't rerializable, there's no point in trying to re-execute it. 
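// For instance (sketch only, reusing the NonSerializable helper from FailureSuite): a job such as
// sc.makeRDD(1 to 3).map(x => new NonSerializable).collect() fails every attempt with the same
// NotSerializableException, so retrying up to maxTaskFailures times only delays the inevitable
// abort; with this change the task set is aborted on the first such failure.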
+ logError("Task %s:%s had a not serializable result: %s; not retrying".format( + taskSet.id, index, ef.description)) + abort("Task %s:%s had a not serializable result: %s".format( + taskSet.id, index, ef.description)) + return + } val key = ef.description val now = clock.getTime() val (printFull, dupCount) = { -- cgit v1.2.3 From 2b0a6e7d9210ed828395243027c7001f7dae77a4 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Fri, 15 Nov 2013 18:34:28 -0800 Subject: Fixed error message in ClusterScheduler to be consistent with the old LocalScheduler --- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index e3929e61ac..7989e6ab32 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -468,6 +468,7 @@ private[spark] class TaskSetManager( removeRunningTask(tid) val index = info.index info.markFailed() + var failureReason = "unknown" if (!successful(index)) { logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) copiesRunning(index) -= 1 @@ -500,6 +501,7 @@ private[spark] class TaskSetManager( return } val key = ef.description + failureReason = "Exception failure: %s".format(ef.description) val now = clock.getTime() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { @@ -525,7 +527,8 @@ private[spark] class TaskSetManager( } case TaskResultLost => - logWarning("Lost result for TID %s on host %s".format(tid, info.host)) + failureReason = "Lost result for TID %s on host %s".format(tid, info.host) + logWarning(failureReason) sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) case _ => {} @@ -537,7 +540,8 @@ private[spark] class TaskSetManager( if (numFailures(index) > maxTaskFailures) { logError("Task %s:%d failed more than %d times; aborting job".format( taskSet.id, index, maxTaskFailures)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, maxTaskFailures)) + abort("Task %s:%d failed more than %d times (most recent failure: %s)".format( + taskSet.id, index, maxTaskFailures, failureReason)) } } } else { -- cgit v1.2.3 From 48e4f2ad141492d7dee579a1b7fb1ec49fefa2ae Mon Sep 17 00:00:00 2001 From: "wangda.tan" Date: Mon, 9 Dec 2013 00:02:59 +0800 Subject: SPARK-968, In stage UI, add an overview section that shows task stats grouped by executor id --- .../org/apache/spark/ui/jobs/ExecutorSummary.scala | 27 +++++++ .../org/apache/spark/ui/jobs/ExecutorTable.scala | 73 ++++++++++++++++++ .../scala/org/apache/spark/ui/jobs/IndexPage.scala | 7 ++ .../apache/spark/ui/jobs/JobProgressListener.scala | 38 +++++++++ .../spark/ui/jobs/JobProgressListenerSuite.scala | 89 ++++++++++++++++++++++ 5 files changed, 234 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala create mode 100644 core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala new file mode 100644 index 0000000000..f2ee12081c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala @@ -0,0 +1,27 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +private[spark] class ExecutorSummary() { + var duration : Long = 0 + var totalTasks : Int = 0 + var failedTasks : Int = 0 + var succeedTasks : Int = 0 + var shuffleRead : Long = 0 + var shuffleWrite : Long = 0 +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala new file mode 100644 index 0000000000..c6823cd823 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + + +import scala.xml.Node + +import org.apache.spark.scheduler.SchedulingMode + + +/** Page showing executor summary */ +private[spark] class ExecutorTable(val parent: JobProgressUI) { + + val listener = parent.listener + val dateFmt = parent.dateFmt + val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + executorTable() + } + } + + /** Special table which merges two header cells. */ + private def executorTable[T](): Seq[Node] = { + + + + + + + + + + + + {createExecutorTable()} + +
+        <th>Executor ID</th>
+        <th>Duration</th>
+        <th>#Tasks</th>
+        <th>#Failed Tasks</th>
+        <th>#Succeed Tasks</th>
+        <th>Shuffle Read</th>
+        <th>Shuffle Write</th>
+ } + + private def createExecutorTable() : Seq[Node] = { + val executorIdToSummary = listener.executorIdToSummary + executorIdToSummary.toSeq.sortBy(_._1).map{ + case (k,v) => { + + {k} + {v.duration} ms + {v.totalTasks} + {v.failedTasks} + {v.succeedTasks} + {v.shuffleRead} + {v.shuffleWrite} + + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala index ca5a28625b..653a84b60f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -45,6 +45,7 @@ private[spark] class IndexPage(parent: JobProgressUI) { val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) + val executorTable = new ExecutorTable(parent) val pools = listener.sc.getAllPools val poolTable = new PoolTable(pools, listener) @@ -56,6 +57,10 @@ private[spark] class IndexPage(parent: JobProgressUI) { {parent.formatDuration(now - listener.sc.startTime)}
             <li>Scheduling Mode: {parent.sc.getSchedulingMode}</li>
+            <li>
+              Executor Summary:
+              {listener.executorIdToSummary.size}
+            </li>
             <li>
               Active Stages:
               {activeStages.size}
@@ -77,6 +82,8 @@
             } else {
               Seq()
             }} ++
+            <h4>Executor Summary</h4> ++
+            executorTable.toNodeSeq++
             <h4>Active Stages ({activeStages.size})</h4> ++
             activeStagesTable.toNodeSeq++
             <h4>Completed Stages ({completedStages.size})</h4>
    ++ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 6b854740d6..2635478592 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -57,6 +57,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val stageIdToTasksFailed = HashMap[Int, Int]() val stageIdToTaskInfos = HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() + val executorIdToSummary = HashMap[String, ExecutorSummary]() override def onJobStart(jobStart: SparkListenerJobStart) {} @@ -114,6 +115,9 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) taskList += ((taskStart.taskInfo, None, None)) stageIdToTaskInfos(sid) = taskList + val executorSummary = executorIdToSummary.getOrElseUpdate(key = taskStart.taskInfo.executorId, + op = new ExecutorSummary()) + executorSummary.totalTasks += 1 } override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) @@ -123,9 +127,43 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList } override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + // update executor summary + val executorSummary = executorIdToSummary.get(taskEnd.taskInfo.executorId) + executorSummary match { + case Some(x) => { + // first update failed-task, succeed-task + taskEnd.reason match { + case e: ExceptionFailure => + x.failedTasks += 1 + case _ => + x.succeedTasks += 1 + } + + // update duration + x.duration += taskEnd.taskInfo.duration + + // update shuffle read/write + val shuffleRead = taskEnd.taskMetrics.shuffleReadMetrics + shuffleRead match { + case Some(s) => + x.shuffleRead += s.remoteBytesRead + case _ => {} + } + val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics + shuffleWrite match { + case Some(s) => { + x.shuffleWrite += s.shuffleBytesWritten + } + case _ => {} + } + } + case _ => {} + } + val sid = taskEnd.task.stageId val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) tasksActive -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = taskEnd.reason match { case e: ExceptionFailure => diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala new file mode 100644 index 0000000000..90a58978c7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import org.scalatest.FunSuite +import org.apache.spark.scheduler._ +import org.apache.spark.SparkContext +import org.apache.spark.Success +import org.apache.spark.scheduler.SparkListenerTaskStart +import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} + +class JobProgressListenerSuite extends FunSuite { + test("test executor id to summary") { + val sc = new SparkContext("local", "joblogger") + val listener = new JobProgressListener(sc) + val taskMetrics = new TaskMetrics() + val shuffleReadMetrics = new ShuffleReadMetrics() + + // nothing in it + assert(listener.executorIdToSummary.size == 0) + + // launched a task, should get an item in map + listener.onTaskStart(new SparkListenerTaskStart( + new ShuffleMapTask(0, null, null, 0, null), + new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL))) + assert(listener.executorIdToSummary.size == 1) + + // finish this task, should get updated shuffleRead + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.executorIdToSummary.getOrElse("exe-1", fail()).shuffleRead == 1000) + + // finish a task with unknown executor-id, nothing should happen + taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.executorIdToSummary.size == 1) + + // launched a task + listener.onTaskStart(new SparkListenerTaskStart( + new ShuffleMapTask(0, null, null, 0, null), + new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL))) + assert(listener.executorIdToSummary.size == 1) + + // finish this task, should get updated duration + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.executorIdToSummary.getOrElse("exe-1", fail()).shuffleRead == 2000) + + // launched a task in another exec + listener.onTaskStart(new SparkListenerTaskStart( + new ShuffleMapTask(0, null, null, 0, null), + new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL))) + assert(listener.executorIdToSummary.size == 2) + + // finish this task, should get updated duration + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + taskInfo = new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.executorIdToSummary.getOrElse("exe-2", fail()).shuffleRead == 1000) + } +} -- cgit v1.2.3 From ee68a85cff499c7aa5d448cc72a93e4de3c23c41 Mon Sep 17 00:00:00 2001 From: "wangda.tan" Date: Mon, 9 Dec 2013 09:38:58 +0800 Subject: SPARK-968, added sc finalize code to avoid akka rebinding 
to the same port --- .../scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 90a58978c7..861d37a862 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -85,5 +85,12 @@ class JobProgressListenerSuite extends FunSuite { listener.onTaskEnd(new SparkListenerTaskEnd( new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) assert(listener.executorIdToSummary.getOrElse("exe-2", fail()).shuffleRead == 1000) + + // do finalize + sc.stop() + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } } -- cgit v1.2.3 From 097e120c0c4132f007bfd0b0254b362ee9a02d8f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 12 Dec 2013 20:41:51 -0800 Subject: Refactored streaming scheduler and added listener interface. - Refactored Scheduler + JobManager to JobGenerator + JobScheduler and added JobSet for cleaner code. Moved scheduler related code to streaming.scheduler package. - Added StreamingListener trait (similar to SparkListener) to enable gathering to streaming stats like processing times and delays. StreamingContext.addListener() to added listeners. - Deduped some code in streaming tests by modifying TestSuiteBase, and added StreamingListenerSuite. --- .../org/apache/spark/scheduler/SparkListener.scala | 2 +- .../org/apache/spark/streaming/Checkpoint.scala | 2 +- .../scala/org/apache/spark/streaming/DStream.scala | 11 +- .../org/apache/spark/streaming/DStreamGraph.scala | 1 + .../scala/org/apache/spark/streaming/Job.scala | 41 ----- .../org/apache/spark/streaming/JobManager.scala | 88 ----------- .../spark/streaming/NetworkInputTracker.scala | 174 -------------------- .../org/apache/spark/streaming/Scheduler.scala | 131 --------------- .../apache/spark/streaming/StreamingContext.scala | 17 +- .../spark/streaming/dstream/ForEachDStream.scala | 3 +- .../streaming/dstream/NetworkInputDStream.scala | 1 + .../spark/streaming/scheduler/BatchInfo.scala | 38 +++++ .../org/apache/spark/streaming/scheduler/Job.scala | 47 ++++++ .../spark/streaming/scheduler/JobGenerator.scala | 127 +++++++++++++++ .../spark/streaming/scheduler/JobScheduler.scala | 104 ++++++++++++ .../apache/spark/streaming/scheduler/JobSet.scala | 61 +++++++ .../streaming/scheduler/NetworkInputTracker.scala | 175 +++++++++++++++++++++ .../streaming/scheduler/StreamingListener.scala | 37 +++++ .../streaming/scheduler/StreamingListenerBus.scala | 81 ++++++++++ .../spark/streaming/BasicOperationsSuite.scala | 12 -- .../apache/spark/streaming/CheckpointSuite.scala | 26 ++- .../org/apache/spark/streaming/FailureSuite.scala | 13 +- .../apache/spark/streaming/InputStreamsSuite.scala | 12 -- .../spark/streaming/StreamingListenerSuite.scala | 71 +++++++++ .../org/apache/spark/streaming/TestSuiteBase.scala | 32 +++- .../spark/streaming/WindowOperationsSuite.scala | 14 +- 26 files changed, 811 insertions(+), 510 deletions(-) delete mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Job.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala delete mode 100644 
streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3841b5616d..2c5d87419d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -63,7 +63,7 @@ trait SparkListener { * Called when a task begins remotely fetching its result (will not be called for tasks that do * not need to fetch the result remotely). */ - def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } /** * Called when a task ends diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 9271914eb5..7b343d2376 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -40,7 +40,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val graph = ssc.graph val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration - val pendingTimes = ssc.scheduler.jobManager.getPendingTimes() + val pendingTimes = ssc.scheduler.getPendingTimes() val delaySeconds = MetadataCleaner.getDelaySeconds def validate() { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index 9ceff754c4..8001c49a76 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -17,23 +17,18 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.dstream._ import StreamingContext._ -import org.apache.spark.util.MetadataCleaner - -//import Time._ - +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.scheduler.Job import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.MetadataCleaner -import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.conf.Configuration /** * A Discretized Stream (DStream), the 
basic abstraction in Spark Streaming, is a continuous diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index b9a58fded6..daed7ff7c3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -21,6 +21,7 @@ import dstream.InputDStream import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import org.apache.spark.Logging +import org.apache.spark.streaming.scheduler.Job final private[streaming] class DStreamGraph extends Serializable with Logging { initLogging() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/Job.scala deleted file mode 100644 index 2128b7c7a6..0000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -import java.util.concurrent.atomic.AtomicLong - -private[streaming] -class Job(val time: Time, func: () => _) { - val id = Job.getNewId() - def run(): Long = { - val startTime = System.currentTimeMillis - func() - val stopTime = System.currentTimeMillis - (stopTime - startTime) - } - - override def toString = "streaming job " + id + " @ " + time -} - -private[streaming] -object Job { - val id = new AtomicLong(0) - - def getNewId() = id.getAndIncrement() -} - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala deleted file mode 100644 index 5233129506..0000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming - -import org.apache.spark.Logging -import org.apache.spark.SparkEnv -import java.util.concurrent.Executors -import collection.mutable.HashMap -import collection.mutable.ArrayBuffer - - -private[streaming] -class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { - - class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { - def run() { - SparkEnv.set(ssc.env) - try { - val timeTaken = job.run() - logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format( - (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0)) - } catch { - case e: Exception => - logError("Running " + job + " failed", e) - } - clearJob(job) - } - } - - initLogging() - - val jobExecutor = Executors.newFixedThreadPool(numThreads) - val jobs = new HashMap[Time, ArrayBuffer[Job]] - - def runJob(job: Job) { - jobs.synchronized { - jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job - } - jobExecutor.execute(new JobHandler(ssc, job)) - logInfo("Added " + job + " to queue") - } - - def stop() { - jobExecutor.shutdown() - } - - private def clearJob(job: Job) { - var timeCleared = false - val time = job.time - jobs.synchronized { - val jobsOfTime = jobs.get(time) - if (jobsOfTime.isDefined) { - jobsOfTime.get -= job - if (jobsOfTime.get.isEmpty) { - jobs -= time - timeCleared = true - } - } else { - throw new Exception("Job finished for time " + job.time + - " but time does not exist in jobs") - } - } - if (timeCleared) { - ssc.scheduler.clearOldMetadata(time) - } - } - - def getPendingTimes(): Array[Time] = { - jobs.synchronized { - jobs.keySet.toArray - } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala deleted file mode 100644 index b97fb7e6e3..0000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming - -import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} -import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} -import org.apache.spark.Logging -import org.apache.spark.SparkEnv -import org.apache.spark.SparkContext._ - -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue - -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ -import org.apache.spark.storage.BlockId - -private[streaming] sealed trait NetworkInputTrackerMessage -private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) extends NetworkInputTrackerMessage -private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage - -/** - * This class manages the execution of the receivers of NetworkInputDStreams. - */ -private[streaming] -class NetworkInputTracker( - @transient ssc: StreamingContext, - @transient networkInputStreams: Array[NetworkInputDStream[_]]) - extends Logging { - - val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) - val receiverExecutor = new ReceiverExecutor() - val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Queue[BlockId]] - val timeout = 5000.milliseconds - - var currentTime: Time = null - - /** Start the actor and receiver execution thread. */ - def start() { - ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") - receiverExecutor.start() - } - - /** Stop the receiver execution thread. */ - def stop() { - // TODO: stop the actor as well - receiverExecutor.interrupt() - receiverExecutor.stopReceivers() - } - - /** Return all the blocks received from a receiver. */ - def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized { - val queue = receivedBlockIds.synchronized { - receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]()) - } - val result = queue.synchronized { - queue.dequeueAll(x => true) - } - logInfo("Stream " + receiverId + " received " + result.size + " blocks") - result.toArray - } - - /** Actor to receive messages from the receivers. */ - private class NetworkInputTrackerActor extends Actor { - def receive = { - case RegisterReceiver(streamId, receiverActor) => { - if (!networkInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) - } - receiverInfo += ((streamId, receiverActor)) - logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) - sender ! true - } - case AddBlocks(streamId, blockIds, metadata) => { - val tmp = receivedBlockIds.synchronized { - if (!receivedBlockIds.contains(streamId)) { - receivedBlockIds += ((streamId, new Queue[BlockId])) - } - receivedBlockIds(streamId) - } - tmp.synchronized { - tmp ++= blockIds - } - networkInputStreamMap(streamId).addMetadata(metadata) - } - case DeregisterReceiver(streamId, msg) => { - receiverInfo -= streamId - logError("De-registered receiver for network stream " + streamId - + " with message " + msg) - //TODO: Do something about the corresponding NetworkInputDStream - } - } - } - - /** This thread class runs all the receivers on the cluster. 
*/ - class ReceiverExecutor extends Thread { - val env = ssc.env - - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") - } finally { - stopReceivers() - } - } - - /** - * Get the receivers from the NetworkInputDStreams, distributes them to the - * worker nodes as a parallel collection, and runs them. - */ - def startReceivers() { - val receivers = networkInputStreams.map(nis => { - val rcvr = nis.getReceiver() - rcvr.setStreamId(nis.id) - rcvr - }) - - // Right now, we only honor preferences if all receivers have them - val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _) - - // Create the parallel collection of receivers to distributed them on the worker nodes - val tempRDD = - if (hasLocationPreferences) { - val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString))) - ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences) - } - else { - ssc.sc.makeRDD(receivers, receivers.size) - } - - // Function to start the receiver on the worker node - val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { - if (!iterator.hasNext) { - throw new Exception("Could not start receiver as details not found.") - } - iterator.next().start() - } - // Run the dummy Spark job to ensure that all slaves have registered. - // This avoids all the receivers to be scheduled on the same node. - ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() - - // Distribute the receivers and start them - ssc.sparkContext.runJob(tempRDD, startReceiver) - } - - /** Stops the receivers. */ - def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.foreach(_ ! StopReceiver) - } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala deleted file mode 100644 index ed892e33e6..0000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming - -import util.{ManualClock, RecurringTimer, Clock} -import org.apache.spark.SparkEnv -import org.apache.spark.Logging - -private[streaming] -class Scheduler(ssc: StreamingContext) extends Logging { - - initLogging() - - val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt - val jobManager = new JobManager(ssc, concurrentJobs) - val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(ssc.checkpointDir) - } else { - null - } - - val clockClass = System.getProperty( - "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") - val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] - val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => generateJobs(new Time(longTime))) - val graph = ssc.graph - var latestTime: Time = null - - def start() = synchronized { - if (ssc.isCheckpointPresent) { - restart() - } else { - startFirstTime() - } - logInfo("Scheduler started") - } - - def stop() = synchronized { - timer.stop() - jobManager.stop() - if (checkpointWriter != null) checkpointWriter.stop() - ssc.graph.stop() - logInfo("Scheduler stopped") - } - - private def startFirstTime() { - val startTime = new Time(timer.getStartTime()) - graph.start(startTime - graph.batchDuration) - timer.start(startTime.milliseconds) - logInfo("Scheduler's timer started at " + startTime) - } - - private def restart() { - - // If manual clock is being used for testing, then - // either set the manual clock to the last checkpointed time, - // or if the property is defined set it to that time - if (clock.isInstanceOf[ManualClock]) { - val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds - val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong - clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) - } - - val batchDuration = ssc.graph.batchDuration - - // Batches when the master was down, that is, - // between the checkpoint and current restart time - val checkpointTime = ssc.initialCheckpoint.checkpointTime - val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) - val downTimes = checkpointTime.until(restartTime, batchDuration) - logInfo("Batches during down time: " + downTimes.mkString(", ")) - - // Batches that were unprocessed before failure - val pendingTimes = ssc.initialCheckpoint.pendingTimes - logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) - // Reschedule jobs for these times - val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) - logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) - timesToReschedule.foreach(time => - graph.generateJobs(time).foreach(jobManager.runJob) - ) - - // Restart the timer - timer.start(restartTime.milliseconds) - logInfo("Scheduler's timer restarted at " + restartTime) - } - - /** Generate jobs and perform checkpoint for the given `time`. */ - def generateJobs(time: Time) { - SparkEnv.set(ssc.env) - logInfo("\n-----------------------------------------------------\n") - graph.generateJobs(time).foreach(jobManager.runJob) - latestTime = time - doCheckpoint(time) - } - - /** - * Clear old metadata assuming jobs of `time` have finished processing. - * And also perform checkpoint. - */ - def clearOldMetadata(time: Time) { - ssc.graph.clearOldMetadata(time) - doCheckpoint(time) - } - - /** Perform checkpoint for the give `time`. 
*/ - def doCheckpoint(time: Time) = synchronized { - if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { - logInfo("Checkpointing graph for time " + time) - ssc.graph.updateCheckpointData(time) - checkpointWriter.write(new Checkpoint(ssc, time)) - } - } -} - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 70bf902143..83f1cadb48 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -46,6 +46,7 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import twitter4j.Status import twitter4j.auth.Authorization +import org.apache.spark.streaming.scheduler._ /** @@ -146,9 +147,10 @@ class StreamingContext private ( } } - protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null - protected[streaming] var receiverJobThread: Thread = null - protected[streaming] var scheduler: Scheduler = null + protected[streaming] val checkpointDuration: Duration = { + if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration + } + protected[streaming] val scheduler = new JobScheduler(this) /** * Return the associated Spark context @@ -510,6 +512,10 @@ class StreamingContext private ( graph.addOutputStream(outputStream) } + def addListener(streamingListener: StreamingListener) { + scheduler.listenerBus.addListener(streamingListener) + } + protected def validate() { assert(graph != null, "Graph is null") graph.validate() @@ -525,9 +531,6 @@ class StreamingContext private ( * Start the execution of the streams. 
*/ def start() { - if (checkpointDir != null && checkpointDuration == null && graph != null) { - checkpointDuration = graph.batchDuration - } validate() @@ -545,7 +548,6 @@ class StreamingContext private ( Thread.sleep(1000) // Start the scheduler - scheduler = new Scheduler(this) scheduler.start() } @@ -556,7 +558,6 @@ class StreamingContext private ( try { if (scheduler != null) scheduler.stop() if (networkInputTracker != null) networkInputTracker.stop() - if (receiverJobThread != null) receiverJobThread.interrupt() sc.stop() logInfo("StreamingContext stopped successfully") } catch { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index e21bac4602..0072248b7d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -18,7 +18,8 @@ package org.apache.spark.streaming.dstream import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, DStream, Job, Time} +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.scheduler.Job private[streaming] class ForEachDStream[T: ClassManifest] ( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index a82862c802..1df7f547c9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -32,6 +32,7 @@ import org.apache.spark.streaming._ import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.rdd.{RDD, BlockRDD} import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId} +import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver} /** * Abstract class for defining any InputDStream that has to start a receiver on worker diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala new file mode 100644 index 0000000000..798598ad50 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming.Time + +case class BatchInfo( + batchTime: Time, + submissionTime: Long, + processingStartTime: Option[Long], + processingEndTime: Option[Long] + ) { + + def schedulingDelay = processingStartTime.map(_ - submissionTime) + + def processingDelay = processingEndTime.zip(processingStartTime).map(x => x._1 - x._2).headOption + + def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption +} + + + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala new file mode 100644 index 0000000000..bca5e1f1a5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.streaming.Time + +private[streaming] +class Job(val time: Time, func: () => _) { + var id: String = _ + + def run(): Long = { + val startTime = System.currentTimeMillis + func() + val stopTime = System.currentTimeMillis + (stopTime - startTime) + } + + def setId(number: Int) { + id = "streaming job " + time + "." + number + } + + override def toString = id +} +/* +private[streaming] +object Job { + val id = new AtomicLong(0) + + def getNewId() = id.getAndIncrement() +} +*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala new file mode 100644 index 0000000000..5d3ce9c398 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.SparkEnv +import org.apache.spark.Logging +import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} +import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} + +private[streaming] +class JobGenerator(jobScheduler: JobScheduler) extends Logging { + + initLogging() + val ssc = jobScheduler.ssc + val clockClass = System.getProperty( + "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") + val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] + val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, + longTime => generateJobs(new Time(longTime))) + val graph = ssc.graph + lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { + new CheckpointWriter(ssc.checkpointDir) + } else { + null + } + + var latestTime: Time = null + + def start() = synchronized { + if (ssc.isCheckpointPresent) { + restart() + } else { + startFirstTime() + } + logInfo("JobGenerator started") + } + + def stop() = synchronized { + timer.stop() + if (checkpointWriter != null) checkpointWriter.stop() + ssc.graph.stop() + logInfo("JobGenerator stopped") + } + + private def startFirstTime() { + val startTime = new Time(timer.getStartTime()) + graph.start(startTime - graph.batchDuration) + timer.start(startTime.milliseconds) + logInfo("JobGenerator's timer started at " + startTime) + } + + private def restart() { + // If manual clock is being used for testing, then + // either set the manual clock to the last checkpointed time, + // or if the property is defined set it to that time + if (clock.isInstanceOf[ManualClock]) { + val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds + val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong + clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) + } + + val batchDuration = ssc.graph.batchDuration + + // Batches when the master was down, that is, + // between the checkpoint and current restart time + val checkpointTime = ssc.initialCheckpoint.checkpointTime + val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) + val downTimes = checkpointTime.until(restartTime, batchDuration) + logInfo("Batches during down time: " + downTimes.mkString(", ")) + + // Batches that were unprocessed before failure + val pendingTimes = ssc.initialCheckpoint.pendingTimes + logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) + // Reschedule jobs for these times + val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) + logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) + timesToReschedule.foreach(time => + jobScheduler.runJobs(time, graph.generateJobs(time)) + ) + + // Restart the timer + timer.start(restartTime.milliseconds) + logInfo("JobGenerator's timer restarted at " + restartTime) + } + + /** Generate jobs and perform checkpoint for the given `time`. */ + private def generateJobs(time: Time) { + SparkEnv.set(ssc.env) + logInfo("\n-----------------------------------------------------\n") + jobScheduler.runJobs(time, graph.generateJobs(time)) + latestTime = time + doCheckpoint(time) + } + + /** + * On batch completion, clear old metadata and checkpoint computation. + */ + private[streaming] def onBatchCompletion(time: Time) { + ssc.graph.clearOldMetadata(time) + doCheckpoint(time) + } + + /** Perform checkpoint for the give `time`. 
*/ + private def doCheckpoint(time: Time) = synchronized { + if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { + logInfo("Checkpointing graph for time " + time) + ssc.graph.updateCheckpointData(time) + checkpointWriter.write(new Checkpoint(ssc, time)) + } + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala new file mode 100644 index 0000000000..14906fd720 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.Logging +import org.apache.spark.SparkEnv +import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} +import scala.collection.mutable.HashSet +import org.apache.spark.streaming._ + + +private[streaming] +class JobScheduler(val ssc: StreamingContext) extends Logging { + + initLogging() + + val jobSets = new ConcurrentHashMap[Time, JobSet] + val numConcurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt + val executor = Executors.newFixedThreadPool(numConcurrentJobs) + val generator = new JobGenerator(this) + val listenerBus = new StreamingListenerBus() + + def clock = generator.clock + + def start() { + generator.start() + } + + def stop() { + generator.stop() + executor.shutdown() + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + } + + def runJobs(time: Time, jobs: Seq[Job]) { + if (jobs.isEmpty) { + logInfo("No jobs added for time " + time) + } else { + val jobSet = new JobSet(time, jobs) + jobSets.put(time, jobSet) + jobSet.jobs.foreach(job => executor.execute(new JobHandler(job))) + logInfo("Added jobs for time " + time) + } + } + + def getPendingTimes(): Array[Time] = { + jobSets.keySet.toArray(new Array[Time](0)) + } + + private def beforeJobStart(job: Job) { + val jobSet = jobSets.get(job.time) + if (!jobSet.hasStarted) { + listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo())) + } + jobSet.beforeJobStart(job) + logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) + SparkEnv.set(generator.ssc.env) + } + + private def afterJobEnd(job: Job) { + val jobSet = jobSets.get(job.time) + jobSet.afterJobStop(job) + logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) + if (jobSet.hasCompleted) { + listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo())) + jobSets.remove(jobSet.time) + generator.onBatchCompletion(jobSet.time) + logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( + jobSet.totalDelay / 1000.0, 
jobSet.time.toString, + jobSet.processingDelay / 1000.0 + )) + } + } + + class JobHandler(job: Job) extends Runnable { + def run() { + beforeJobStart(job) + try { + job.run() + } catch { + case e: Exception => + logError("Running " + job + " failed", e) + } + afterJobEnd(job) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala new file mode 100644 index 0000000000..05233d095b --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable.HashSet +import org.apache.spark.streaming.Time + +private[streaming] +case class JobSet(time: Time, jobs: Seq[Job]) { + + private val incompleteJobs = new HashSet[Job]() + var submissionTime = System.currentTimeMillis() + var processingStartTime = -1L + var processingEndTime = -1L + + jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) } + incompleteJobs ++= jobs + + def beforeJobStart(job: Job) { + if (processingStartTime < 0) processingStartTime = System.currentTimeMillis() + } + + def afterJobStop(job: Job) { + incompleteJobs -= job + if (hasCompleted) processingEndTime = System.currentTimeMillis() + } + + def hasStarted() = (processingStartTime > 0) + + def hasCompleted() = incompleteJobs.isEmpty + + def processingDelay = processingEndTime - processingStartTime + + def totalDelay = { + processingEndTime - time.milliseconds + } + + def toBatchInfo(): BatchInfo = { + new BatchInfo( + time, + submissionTime, + if (processingStartTime >= 0 ) Some(processingStartTime) else None, + if (processingEndTime >= 0 ) Some(processingEndTime) else None + ) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala new file mode 100644 index 0000000000..c759302a61 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} +import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} +import org.apache.spark.Logging +import org.apache.spark.SparkEnv +import org.apache.spark.SparkContext._ + +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ +import org.apache.spark.storage.BlockId +import org.apache.spark.streaming.{Time, StreamingContext} + +private[streaming] sealed trait NetworkInputTrackerMessage +private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage +private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) extends NetworkInputTrackerMessage +private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage + +/** + * This class manages the execution of the receivers of NetworkInputDStreams. + */ +private[streaming] +class NetworkInputTracker( + @transient ssc: StreamingContext, + @transient networkInputStreams: Array[NetworkInputDStream[_]]) + extends Logging { + + val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) + val receiverExecutor = new ReceiverExecutor() + val receiverInfo = new HashMap[Int, ActorRef] + val receivedBlockIds = new HashMap[Int, Queue[BlockId]] + val timeout = 5000.milliseconds + + var currentTime: Time = null + + /** Start the actor and receiver execution thread. */ + def start() { + ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") + receiverExecutor.start() + } + + /** Stop the receiver execution thread. */ + def stop() { + // TODO: stop the actor as well + receiverExecutor.interrupt() + receiverExecutor.stopReceivers() + } + + /** Return all the blocks received from a receiver. */ + def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized { + val queue = receivedBlockIds.synchronized { + receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]()) + } + val result = queue.synchronized { + queue.dequeueAll(x => true) + } + logInfo("Stream " + receiverId + " received " + result.size + " blocks") + result.toArray + } + + /** Actor to receive messages from the receivers. */ + private class NetworkInputTrackerActor extends Actor { + def receive = { + case RegisterReceiver(streamId, receiverActor) => { + if (!networkInputStreamMap.contains(streamId)) { + throw new Exception("Register received for unexpected id " + streamId) + } + receiverInfo += ((streamId, receiverActor)) + logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) + sender ! 
true + } + case AddBlocks(streamId, blockIds, metadata) => { + val tmp = receivedBlockIds.synchronized { + if (!receivedBlockIds.contains(streamId)) { + receivedBlockIds += ((streamId, new Queue[BlockId])) + } + receivedBlockIds(streamId) + } + tmp.synchronized { + tmp ++= blockIds + } + networkInputStreamMap(streamId).addMetadata(metadata) + } + case DeregisterReceiver(streamId, msg) => { + receiverInfo -= streamId + logError("De-registered receiver for network stream " + streamId + + " with message " + msg) + //TODO: Do something about the corresponding NetworkInputDStream + } + } + } + + /** This thread class runs all the receivers on the cluster. */ + class ReceiverExecutor extends Thread { + val env = ssc.env + + override def run() { + try { + SparkEnv.set(env) + startReceivers() + } catch { + case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") + } finally { + stopReceivers() + } + } + + /** + * Get the receivers from the NetworkInputDStreams, distributes them to the + * worker nodes as a parallel collection, and runs them. + */ + def startReceivers() { + val receivers = networkInputStreams.map(nis => { + val rcvr = nis.getReceiver() + rcvr.setStreamId(nis.id) + rcvr + }) + + // Right now, we only honor preferences if all receivers have them + val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _) + + // Create the parallel collection of receivers to distributed them on the worker nodes + val tempRDD = + if (hasLocationPreferences) { + val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString))) + ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences) + } + else { + ssc.sc.makeRDD(receivers, receivers.size) + } + + // Function to start the receiver on the worker node + val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { + if (!iterator.hasNext) { + throw new Exception("Could not start receiver as details not found.") + } + iterator.next().start() + } + // Run the dummy Spark job to ensure that all slaves have registered. + // This avoids all the receivers to be scheduled on the same node. + ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() + + // Distribute the receivers and start them + ssc.sparkContext.runJob(tempRDD, startReceiver) + } + + /** Stops the receivers. */ + def stopReceivers() { + // Signal the receivers to stop + receiverInfo.values.foreach(_ ! StopReceiver) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala new file mode 100644 index 0000000000..49fd0d29c3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +sealed trait StreamingListenerEvent + +case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent + +case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent + +trait StreamingListener { + + /** + * Called when processing of a batch has completed + */ + def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { } + + /** + * Called when processing of a batch has started + */ + def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { } +} \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala new file mode 100644 index 0000000000..324e491914 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.Logging +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.concurrent.LinkedBlockingQueue + +/** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */ +private[spark] class StreamingListenerBus() extends Logging { + private val listeners = new ArrayBuffer[StreamingListener]() with SynchronizedBuffer[StreamingListener] + + /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than + * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY) + private var queueFullErrorMessageLogged = false + + new Thread("StreamingListenerBus") { + setDaemon(true) + override def run() { + while (true) { + val event = eventQueue.take + event match { + case batchStarted: StreamingListenerBatchStarted => + listeners.foreach(_.onBatchStarted(batchStarted)) + case batchCompleted: StreamingListenerBatchCompleted => + listeners.foreach(_.onBatchCompleted(batchCompleted)) + case _ => + } + } + } + }.start() + + def addListener(listener: StreamingListener) { + listeners += listener + } + + def post(event: StreamingListenerEvent) { + val eventAdded = eventQueue.offer(event) + if (!eventAdded && !queueFullErrorMessageLogged) { + logError("Dropping SparkListenerEvent because no remaining room in event queue. 
" + + "This likely means one of the SparkListeners is too slow and cannot keep up with the " + + "rate at which tasks are being started by the scheduler.") + queueFullErrorMessageLogged = true + } + } + + /** + * Waits until there are no more events in the queue, or until the specified time has elapsed. + * Used for testing only. Returns true if the queue has emptied and false is the specified time + * elapsed before the queue emptied. + */ + def waitUntilEmpty(timeoutMillis: Int): Boolean = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!eventQueue.isEmpty()) { + if (System.currentTimeMillis > finishTime) { + return false + } + /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify + * add overhead in the general case. */ + Thread.sleep(10) + } + return true + } +} \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 259ef1608c..b35ca00b53 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -26,18 +26,6 @@ import util.ManualClock class BasicOperationsSuite extends TestSuiteBase { - override def framework() = "BasicOperationsSuite" - - before { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - } - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - test("map") { val input = Seq(1 to 4, 5 to 8, 9 to 12) testOperation( diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index beb20831bd..c93075e3b3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -34,31 +34,25 @@ import com.google.common.io.Files * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. 
*/ -class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { +class CheckpointSuite extends TestSuiteBase { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + var ssc: StreamingContext = null + + override def batchDuration = Milliseconds(500) + + override def actuallyWait = true // to allow checkpoints to be written - before { + override def beforeFunction() { + super.beforeFunction() FileUtils.deleteDirectory(new File(checkpointDir)) } - after { + override def afterFunction() { + super.afterFunction() if (ssc != null) ssc.stop() FileUtils.deleteDirectory(new File(checkpointDir)) - - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") } - var ssc: StreamingContext = null - - override def framework = "CheckpointSuite" - - override def batchDuration = Milliseconds(500) - - override def actuallyWait = true - test("basic rdd checkpoints + dstream graph checkpoint recovery") { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 6337c5359c..da9b04de1a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -32,17 +32,22 @@ import collection.mutable.ArrayBuffer * This testsuite tests master failures at random times while the stream is running using * the real clock. */ -class FailureSuite extends FunSuite with BeforeAndAfter with Logging { +class FailureSuite extends TestSuiteBase with Logging { var directory = "FailureSuite" val numBatches = 30 - val batchDuration = Milliseconds(1000) - before { + override def batchDuration = Milliseconds(1000) + + override def useManualClock = false + + override def beforeFunction() { + super.beforeFunction() FileUtils.deleteDirectory(new File(directory)) } - after { + override def afterFunction() { + super.afterFunction() FileUtils.deleteDirectory(new File(directory)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 7dc82decef..62a9f120b4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -50,18 +50,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val testPort = 9999 - override def checkpointDir = "checkpoint" - - before { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - } - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - test("socket input stream") { // Start the server val testServer = new TestServer() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala new file mode 100644 index 0000000000..826c839932 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.apache.spark.streaming.scheduler._ +import scala.collection.mutable.ArrayBuffer +import org.scalatest.matchers.ShouldMatchers + +class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers{ + + val input = (1 to 4).map(Seq(_)).toSeq + val operation = (d: DStream[Int]) => d.map(x => x) + + // To make sure that the processing start and end times in collected + // information are different for successive batches + override def batchDuration = Milliseconds(100) + override def actuallyWait = true + + test("basic BatchInfo generation") { + val ssc = setupStreams(input, operation) + val collector = new BatchInfoCollector + ssc.addListener(collector) + runStreams(ssc, input.size, input.size) + val batchInfos = collector.batchInfos + batchInfos should have size 4 + + batchInfos.foreach(info => { + info.schedulingDelay should not be None + info.processingDelay should not be None + info.totalDelay should not be None + info.schedulingDelay.get should be >= 0L + info.processingDelay.get should be >= 0L + info.totalDelay.get should be >= 0L + }) + + isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true) + } + + /** Check if a sequence of numbers is in increasing order */ + def isInIncreasingOrder(seq: Seq[Long]): Boolean = { + for(i <- 1 until seq.size) { + if (seq(i - 1) > seq(i)) return false + } + true + } + + /** Listener that collects information on processed batches */ + class BatchInfoCollector extends StreamingListener { + val batchInfos = new ArrayBuffer[BatchInfo] + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + batchInfos += batchCompleted.batchInfo + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 8c8c359e6e..fbbeb8f0ee 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -109,7 +109,7 @@ class TestOutputStreamWithPartitions[T: ClassManifest](parent: DStream[T], trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = "TestSuiteBase" + def framework = this.getClass.getSimpleName // Master for Spark context def master = "local[2]" @@ -126,9 +126,39 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Maximum time to wait before the test times out def maxWaitTimeMillis = 10000 + // Whether to use manual clock or not + def useManualClock = true + // Whether to actually 
wait in real time before changing manual clock def actuallyWait = false + // Default before function for any streaming test suite. Override this + // if you want to add your stuff to "before" (i.e., don't call before { } ) + def beforeFunction() { + if (useManualClock) { + System.setProperty( + "spark.streaming.clock", + "org.apache.spark.streaming.util.ManualClock" + ) + } else { + System.clearProperty("spark.streaming.clock") + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + // Default after function for any streaming test suite. Override this + // if you want to add your stuff to "after" (i.e., don't call after { } ) + def afterFunction() { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + before(beforeFunction) + after(afterFunction) + /** * Set up required DStreams to test the DStream operation using the two sequences * of input collections. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index f50e05c0d8..6b4aaefcdf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -22,19 +22,9 @@ import collection.mutable.ArrayBuffer class WindowOperationsSuite extends TestSuiteBase { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer - override def framework = "WindowOperationsSuite" - - override def maxWaitTimeMillis = 20000 - - override def batchDuration = Seconds(1) - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } + override def batchDuration = Seconds(1) // making sure its visible in this class val largerSlideInput = Seq( Seq(("a", 1)), -- cgit v1.2.3 From 36060f4f50ead2632117bb12e8c5bc1fb4f91f1e Mon Sep 17 00:00:00 2001 From: "wangda.tan" Date: Tue, 17 Dec 2013 17:55:38 +0800 Subject: spark-898, changes according to review comments --- .../org/apache/spark/ui/exec/ExecutorsUI.scala | 39 +++++++++++++++++-- .../org/apache/spark/ui/jobs/ExecutorSummary.scala | 3 +- .../org/apache/spark/ui/jobs/ExecutorTable.scala | 40 ++++++++++--------- .../scala/org/apache/spark/ui/jobs/IndexPage.scala | 5 +-- .../apache/spark/ui/jobs/JobProgressListener.scala | 31 ++++++++------- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 3 +- .../spark/ui/jobs/JobProgressListenerSuite.scala | 45 ++++++---------------- 7 files changed, 90 insertions(+), 76 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index e596690bc3..808bbe8c8f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -56,7 +56,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", 
"Disk used", - "Active tasks", "Failed tasks", "Complete tasks", "Total tasks") + "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Duration", "Shuffle Read", + "Shuffle Write") def execRow(kv: Seq[String]) = { @@ -73,6 +74,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { {kv(7)} {kv(8)} {kv(9)} + {Utils.msDurationToString(kv(10).toLong)} + {Utils.bytesToString(kv(11).toLong)} + {Utils.bytesToString(kv(12).toLong)} } @@ -111,6 +115,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks + val totalDuration = listener.executorToDuration.getOrElse(execId, 0) + val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0) + val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0) Seq( execId, @@ -122,7 +129,10 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { activeTasks.toString, failedTasks.toString, completedTasks.toString, - totalTasks.toString + totalTasks.toString, + totalDuration.toString, + totalShuffleRead.toString, + totalShuffleWrite.toString ) } @@ -130,6 +140,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]() val executorToTasksComplete = HashMap[String, Int]() val executorToTasksFailed = HashMap[String, Int]() + val executorToDuration = HashMap[String, Long]() + val executorToShuffleRead = HashMap[String, Long]() + val executorToShuffleWrite = HashMap[String, Long]() override def onTaskStart(taskStart: SparkListenerTaskStart) { val eid = taskStart.taskInfo.executorId @@ -137,9 +150,12 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { activeTasks += taskStart.taskInfo } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val eid = taskEnd.taskInfo.executorId val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + val newDuration = executorToDuration.getOrElse(eid, 0L) + taskEnd.taskInfo.duration + executorToDuration.put(eid, newDuration) + activeTasks -= taskEnd.taskInfo val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = taskEnd.reason match { @@ -150,6 +166,23 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 (None, Option(taskEnd.taskMetrics)) } + + // update shuffle read/write + val shuffleRead = taskEnd.taskMetrics.shuffleReadMetrics + shuffleRead match { + case Some(s) => + val newShuffleRead = executorToShuffleRead.getOrElse(eid, 0L) + s.remoteBytesRead + executorToShuffleRead.put(eid, newShuffleRead) + case _ => {} + } + val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics + shuffleWrite match { + case Some(s) => { + val newShuffleWrite = executorToShuffleWrite.getOrElse(eid, 0L) + s.shuffleBytesWritten + executorToShuffleWrite.put(eid, newShuffleWrite) + } + case _ => {} + } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala index f2ee12081c..75c0dd2c7f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala @@ -19,9 +19,8 @@ package org.apache.spark.ui.jobs private[spark] class 
ExecutorSummary() { var duration : Long = 0 - var totalTasks : Int = 0 var failedTasks : Int = 0 - var succeedTasks : Int = 0 + var succeededTasks : Int = 0 var shuffleRead : Long = 0 var shuffleWrite : Long = 0 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index c6823cd823..763d5a344b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -17,14 +17,13 @@ package org.apache.spark.ui.jobs - import scala.xml.Node import org.apache.spark.scheduler.SchedulingMode - +import org.apache.spark.util.Utils /** Page showing executor summary */ -private[spark] class ExecutorTable(val parent: JobProgressUI) { +private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) { val listener = parent.listener val dateFmt = parent.dateFmt @@ -42,9 +41,9 @@ private[spark] class ExecutorTable(val parent: JobProgressUI) { Executor ID Duration - #Tasks - #Failed Tasks - #Succeed Tasks + Total Tasks + Failed Tasks + Succeeded Tasks Shuffle Read Shuffle Write @@ -55,19 +54,24 @@ private[spark] class ExecutorTable(val parent: JobProgressUI) { } private def createExecutorTable() : Seq[Node] = { - val executorIdToSummary = listener.executorIdToSummary - executorIdToSummary.toSeq.sortBy(_._1).map{ - case (k,v) => { - - {k} - {v.duration} ms - {v.totalTasks} - {v.failedTasks} - {v.succeedTasks} - {v.shuffleRead} - {v.shuffleWrite} - + val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId) + executorIdToSummary match { + case Some(x) => { + x.toSeq.sortBy(_._1).map{ + case (k,v) => { + + {k} + {parent.formatDuration(v.duration)} + {v.failedTasks + v.succeededTasks} + {v.failedTasks} + {v.succeededTasks} + {Utils.bytesToString(v.shuffleRead)} + {Utils.bytesToString(v.shuffleWrite)} + + } + } } + case _ => { Seq[Node]() } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala index 653a84b60f..854afb665a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -45,7 +45,6 @@ private[spark] class IndexPage(parent: JobProgressUI) { val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) - val executorTable = new ExecutorTable(parent) val pools = listener.sc.getAllPools val poolTable = new PoolTable(pools, listener) @@ -59,7 +58,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
  • Scheduling Mode: {parent.sc.getSchedulingMode}
  • Executor Summary: - {listener.executorIdToSummary.size} + {listener.stageIdToExecutorSummaries.size}
  • Active Stages: @@ -82,8 +81,6 @@ private[spark] class IndexPage(parent: JobProgressUI) { } else { Seq() }} ++ -

    Executor Summary

    ++ - executorTable.toNodeSeq++

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq++

    Completed Stages ({completedStages.size})

    ++ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 2635478592..8c92ff19a6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -57,7 +57,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val stageIdToTasksFailed = HashMap[Int, Int]() val stageIdToTaskInfos = HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() - val executorIdToSummary = HashMap[String, ExecutorSummary]() + val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]() override def onJobStart(jobStart: SparkListenerJobStart) {} @@ -115,9 +115,6 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) taskList += ((taskStart.taskInfo, None, None)) stageIdToTaskInfos(sid) = taskList - val executorSummary = executorIdToSummary.getOrElseUpdate(key = taskStart.taskInfo.executorId, - op = new ExecutorSummary()) - executorSummary.totalTasks += 1 } override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) @@ -127,32 +124,39 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList } override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { - // update executor summary - val executorSummary = executorIdToSummary.get(taskEnd.taskInfo.executorId) + val sid = taskEnd.task.stageId + + // create executor summary map if necessary + val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid, + op = new HashMap[String, ExecutorSummary]()) + executorSummaryMap.getOrElseUpdate(key = taskEnd.taskInfo.executorId, + op = new ExecutorSummary()) + + val executorSummary = executorSummaryMap.get(taskEnd.taskInfo.executorId) executorSummary match { - case Some(x) => { + case Some(y) => { // first update failed-task, succeed-task taskEnd.reason match { - case e: ExceptionFailure => - x.failedTasks += 1 + case Success => + y.succeededTasks += 1 case _ => - x.succeedTasks += 1 + y.failedTasks += 1 } // update duration - x.duration += taskEnd.taskInfo.duration + y.duration += taskEnd.taskInfo.duration // update shuffle read/write val shuffleRead = taskEnd.taskMetrics.shuffleReadMetrics shuffleRead match { case Some(s) => - x.shuffleRead += s.remoteBytesRead + y.shuffleRead += s.remoteBytesRead case _ => {} } val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics shuffleWrite match { case Some(s) => { - x.shuffleWrite += s.shuffleBytesWritten + y.shuffleWrite += s.shuffleBytesWritten } case _ => {} } @@ -160,7 +164,6 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList case _ => {} } - val sid = taskEnd.task.stageId val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) tasksActive -= taskEnd.taskInfo diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 69f9446bab..c077613b1d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -160,9 +160,10 @@ private[spark] class StagePage(parent: JobProgressUI) { def quantileRow(data: Seq[String]): Seq[Node] = {data.map(d => {d})} Some(listingTable(quantileHeaders, quantileRow, 
listings, fixedWidth = true)) } - + val executorTable = new ExecutorTable(parent, stageId) val content = summary ++ +

    Summary Metrics for Executors

    ++ executorTable.toNodeSeq() ++

    Summary Metrics for {numCompleted} Completed Tasks

    ++
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++

    Tasks

    ++ taskTable diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 861d37a862..67a57a0e7f 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -19,26 +19,19 @@ package org.apache.spark.ui.jobs import org.scalatest.FunSuite import org.apache.spark.scheduler._ -import org.apache.spark.SparkContext -import org.apache.spark.Success +import org.apache.spark.{LocalSparkContext, SparkContext, Success} import org.apache.spark.scheduler.SparkListenerTaskStart import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} -class JobProgressListenerSuite extends FunSuite { +class JobProgressListenerSuite extends FunSuite with LocalSparkContext { test("test executor id to summary") { - val sc = new SparkContext("local", "joblogger") + val sc = new SparkContext("local", "test") val listener = new JobProgressListener(sc) val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() // nothing in it - assert(listener.executorIdToSummary.size == 0) - - // launched a task, should get an item in map - listener.onTaskStart(new SparkListenerTaskStart( - new ShuffleMapTask(0, null, null, 0, null), - new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL))) - assert(listener.executorIdToSummary.size == 1) + assert(listener.stageIdToExecutorSummaries.size == 0) // finish this task, should get updated shuffleRead shuffleReadMetrics.remoteBytesRead = 1000 @@ -47,20 +40,15 @@ class JobProgressListenerSuite extends FunSuite { taskInfo.finishTime = 1 listener.onTaskEnd(new SparkListenerTaskEnd( new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) - assert(listener.executorIdToSummary.getOrElse("exe-1", fail()).shuffleRead == 1000) + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) + .shuffleRead == 1000) // finish a task with unknown executor-id, nothing should happen taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL) taskInfo.finishTime = 1 listener.onTaskEnd(new SparkListenerTaskEnd( new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) - assert(listener.executorIdToSummary.size == 1) - - // launched a task - listener.onTaskStart(new SparkListenerTaskStart( - new ShuffleMapTask(0, null, null, 0, null), - new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL))) - assert(listener.executorIdToSummary.size == 1) + assert(listener.stageIdToExecutorSummaries.size == 1) // finish this task, should get updated duration shuffleReadMetrics.remoteBytesRead = 1000 @@ -69,13 +57,8 @@ class JobProgressListenerSuite extends FunSuite { taskInfo.finishTime = 1 listener.onTaskEnd(new SparkListenerTaskEnd( new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) - assert(listener.executorIdToSummary.getOrElse("exe-1", fail()).shuffleRead == 2000) - - // launched a task in another exec - listener.onTaskStart(new SparkListenerTaskStart( - new ShuffleMapTask(0, null, null, 0, null), - new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL))) - assert(listener.executorIdToSummary.size == 2) + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) + .shuffleRead == 2000) // finish this task, should get updated duration 
shuffleReadMetrics.remoteBytesRead = 1000 @@ -84,13 +67,7 @@ class JobProgressListenerSuite extends FunSuite { taskInfo.finishTime = 1 listener.onTaskEnd(new SparkListenerTaskEnd( new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) - assert(listener.executorIdToSummary.getOrElse("exe-2", fail()).shuffleRead == 1000) - - // do finalize - sc.stop() - - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail()) + .shuffleRead == 1000) } } -- cgit v1.2.3 From 59e53fa21caa202a57093c74ada128fca2be5bac Mon Sep 17 00:00:00 2001 From: "wangda.tan" Date: Tue, 17 Dec 2013 17:57:27 +0800 Subject: spark-968, changes for avoid a NPE --- .../org/apache/spark/ui/exec/ExecutorsUI.scala | 30 ++++++++++++---------- .../apache/spark/ui/jobs/JobProgressListener.scala | 24 +++++++++-------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 808bbe8c8f..f62ae37466 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -150,7 +150,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { activeTasks += taskStart.taskInfo } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { val eid = taskEnd.taskInfo.executorId val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) val newDuration = executorToDuration.getOrElse(eid, 0L) + taskEnd.taskInfo.duration @@ -168,20 +168,22 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { } // update shuffle read/write - val shuffleRead = taskEnd.taskMetrics.shuffleReadMetrics - shuffleRead match { - case Some(s) => - val newShuffleRead = executorToShuffleRead.getOrElse(eid, 0L) + s.remoteBytesRead - executorToShuffleRead.put(eid, newShuffleRead) - case _ => {} - } - val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics - shuffleWrite match { - case Some(s) => { - val newShuffleWrite = executorToShuffleWrite.getOrElse(eid, 0L) + s.shuffleBytesWritten - executorToShuffleWrite.put(eid, newShuffleWrite) + if (null != taskEnd.taskMetrics) { + val shuffleRead = taskEnd.taskMetrics.shuffleReadMetrics + shuffleRead match { + case Some(s) => + val newShuffleRead = executorToShuffleRead.getOrElse(eid, 0L) + s.remoteBytesRead + executorToShuffleRead.put(eid, newShuffleRead) + case _ => {} + } + val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics + shuffleWrite match { + case Some(s) => { + val newShuffleWrite = executorToShuffleWrite.getOrElse(eid, 0L) + s.shuffleBytesWritten + executorToShuffleWrite.put(eid, newShuffleWrite) + } + case _ => {} } - case _ => {} } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 8c92ff19a6..64ce715993 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -147,18 +147,20 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList y.duration += taskEnd.taskInfo.duration // update shuffle read/write - val shuffleRead = 
taskEnd.taskMetrics.shuffleReadMetrics - shuffleRead match { - case Some(s) => - y.shuffleRead += s.remoteBytesRead - case _ => {} - } - val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics - shuffleWrite match { - case Some(s) => { - y.shuffleWrite += s.shuffleBytesWritten + if (null != taskEnd.taskMetrics) { + val shuffleRead = taskEnd.taskMetrics.shuffleReadMetrics + shuffleRead match { + case Some(s) => + y.shuffleRead += s.remoteBytesRead + case _ => {} + } + val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics + shuffleWrite match { + case Some(s) => { + y.shuffleWrite += s.shuffleBytesWritten + } + case _ => {} } - case _ => {} } } case _ => {} -- cgit v1.2.3 From d5b260e7dd17c43e45f5c16c663d3479fb8757d1 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 19 Dec 2013 02:16:04 +0900 Subject: Change the order of CLASSPATH. --- spark-class | 2 +- spark-class2.cmd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/spark-class b/spark-class index 4fa6fb864e..ff51fbd557 100755 --- a/spark-class +++ b/spark-class @@ -124,7 +124,7 @@ fi # Compute classpath using external script CLASSPATH=`$FWDIR/bin/compute-classpath.sh` -CLASSPATH="$SPARK_TOOLS_JAR:$CLASSPATH" +CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" export CLASSPATH if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then diff --git a/spark-class2.cmd b/spark-class2.cmd index 3869d0761b..a60c17d050 100644 --- a/spark-class2.cmd +++ b/spark-class2.cmd @@ -75,7 +75,7 @@ rem Compute classpath using external script set DONT_PRINT_CLASSPATH=1 call "%FWDIR%bin\compute-classpath.cmd" set DONT_PRINT_CLASSPATH=0 -set CLASSPATH=%SPARK_TOOLS_JAR%;%CLASSPATH% +set CLASSPATH=%CLASSPATH%;%SPARK_TOOLS_JAR% rem Figure out where java is. set RUNNER=java -- cgit v1.2.3 From b80ec05635132f96772545803a10a1bbfa1250e7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Dec 2013 15:35:24 -0800 Subject: Added StatsReportListener to generate processing time statistics across multiple batches. 
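For illustration, a minimal sketch of wiring up the new StatsReportListener, assuming a StreamingContext created with the (master, appName, batchDuration) constructor and registered through the addStreamingListener method added later in this series; only StatsReportListener and its numBatchInfos argument come from this patch:

    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.scheduler.StatsReportListener

    val ssc = new StreamingContext("local[2]", "stats-example", Seconds(1))
    // Keep the last 20 completed batches and log distributions of their
    // total delay and processing time after each batch completes.
    ssc.addStreamingListener(new StatsReportListener(numBatchInfos = 20))
    // ... register input streams and output operations here, then:
    ssc.start()
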
--- .../org/apache/spark/scheduler/SparkListener.scala | 5 +-- .../spark/streaming/scheduler/JobScheduler.scala | 2 +- .../streaming/scheduler/StreamingListener.scala | 45 +++++++++++++++++++++- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 2c5d87419d..ee63b3c4a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -131,8 +131,8 @@ object StatsReportListener extends Logging { def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { val stats = d.statCounter - logInfo(heading + stats) val quantiles = d.getQuantiles(probabilities).map{formatNumber} + logInfo(heading + stats) logInfo(percentilesHeader) logInfo("\t" + quantiles.mkString("\t")) } @@ -173,8 +173,6 @@ object StatsReportListener extends Logging { showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) } - - val seconds = 1000L val minutes = seconds * 60 val hours = minutes * 60 @@ -198,7 +196,6 @@ object StatsReportListener extends Logging { } - case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) object RuntimePercentage { def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 14906fd720..69930f3b6c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -79,13 +79,13 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { jobSet.afterJobStop(job) logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) if (jobSet.hasCompleted) { - listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo())) jobSets.remove(jobSet.time) generator.onBatchCompletion(jobSet.time) logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( jobSet.totalDelay / 1000.0, jobSet.time.toString, jobSet.processingDelay / 1000.0 )) + listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo())) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 49fd0d29c3..5647ffab8d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -17,14 +17,22 @@ package org.apache.spark.streaming.scheduler +import scala.collection.mutable.Queue +import org.apache.spark.util.Distribution + +/** Base trait for events related to StreamingListener */ sealed trait StreamingListenerEvent case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent -trait StreamingListener { +/** + * A listener interface for receiving information about an ongoing streaming + * computation. 
+ */ +trait StreamingListener { /** * Called when processing of a batch has completed */ @@ -34,4 +42,39 @@ trait StreamingListener { * Called when processing of a batch has started */ def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { } +} + + +/** + * A simple StreamingListener that logs summary statistics across Spark Streaming batches + * @param numBatchInfos Number of last batches to consider for generating statistics (default: 10) + */ +class StatsReportListener(numBatchInfos: Int = 10) extends StreamingListener { + + import org.apache.spark + + val batchInfos = new Queue[BatchInfo]() + + override def onBatchCompleted(batchStarted: StreamingListenerBatchCompleted) { + addToQueue(batchStarted.batchInfo) + printStats() + } + + def addToQueue(newPoint: BatchInfo) { + batchInfos.enqueue(newPoint) + if (batchInfos.size > numBatchInfos) batchInfos.dequeue() + } + + def printStats() { + showMillisDistribution("Total delay: ", _.totalDelay) + showMillisDistribution("Processing time: ", _.processingDelay) + } + + def showMillisDistribution(heading: String, getMetric: BatchInfo => Option[Long]) { + spark.scheduler.StatsReportListener.showMillisDistribution(heading, extractDistribution(getMetric)) + } + + def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { + Distribution(batchInfos.flatMap(getMetric(_)).map(_.toDouble)) + } } \ No newline at end of file -- cgit v1.2.3 From 95915f8b3b6d07a9dddb09a637aa23c8622bff9b Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Thu, 19 Dec 2013 01:22:18 -0500 Subject: First cut at python mllib bindings. Only LinearRegression is supported. --- .../apache/spark/mllib/api/PythonMLLibAPI.scala | 51 +++++++++ python/pyspark/mllib.py | 114 +++++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala create mode 100644 python/pyspark/mllib.py diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala new file mode 100644 index 0000000000..19d2e9a773 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -0,0 +1,51 @@ +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.regression._ +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.DoubleBuffer + +class PythonMLLibAPI extends Serializable { + def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { + val packetLength = bytes.length; + if (packetLength < 16) { + throw new IllegalArgumentException("Byte array too short."); + } + val bb = ByteBuffer.wrap(bytes); + bb.order(ByteOrder.nativeOrder()); + val magic = bb.getLong(); + if (magic != 1) { + throw new IllegalArgumentException("Magic " + magic + " is wrong."); + } + val length = bb.getLong(); + if (packetLength != 16 + 8 * length) { + throw new IllegalArgumentException("Length " + length + "is wrong."); + } + val db = bb.asDoubleBuffer(); + val ans = new Array[Double](length.toInt); + db.get(ans); + return ans; + } + + def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { + val len = doubles.length; + val bytes = new Array[Byte](16 + 8 * len); + val bb = ByteBuffer.wrap(bytes); + bb.order(ByteOrder.nativeOrder()); + bb.putLong(1); + bb.putLong(len); + val db = bb.asDoubleBuffer(); + db.put(doubles); + return bytes; + } + + def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]): + java.util.List[java.lang.Object] = { + val data 
= dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x)) + .map(v => LabeledPoint(v(0), v.slice(1, v.length))); + val model = LinearRegressionWithSGD.train(data, 222); + val ret = new java.util.LinkedList[java.lang.Object](); + ret.add(serializeDoubleVector(model.weights)); + ret.add(model.intercept: java.lang.Double); + return ret; + } +} diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py new file mode 100644 index 0000000000..8237f66d67 --- /dev/null +++ b/python/pyspark/mllib.py @@ -0,0 +1,114 @@ +from numpy import *; +from pyspark.serializers import NoOpSerializer, FramedSerializer, \ + BatchedSerializer, CloudPickleSerializer, pack_long + +#__all__ = ["train_linear_regression_model"]; + +# Double vector format: +# +# [8-byte 1] [8-byte length] [length*8 bytes of data] +# +# Double matrix format: +# +# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data] +# +# This is all in machine-endian. That means that the Java interpreter and the +# Python interpreter must agree on what endian the machine is. + +def deserialize_byte_array(shape, ba, offset): + """Implementation detail. Do not use directly.""" + ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", \ + order='C'); + return ar.copy(); + +def serialize_double_vector(v): + """Implementation detail. Do not use directly.""" + if (type(v) == ndarray and v.dtype == float64 and v.ndim == 1): + length = v.shape[0]; + ba = bytearray(16 + 8*length); + header = ndarray(shape=[2], buffer=ba, dtype="int64"); + header[0] = 1; + header[1] = length; + copyto(ndarray(shape=[length], buffer=ba, offset=16, dtype="float64"), v); + return ba; + else: + raise TypeError("serialize_double_vector called on a non-double-vector"); + +def deserialize_double_vector(ba): + """Implementation detail. Do not use directly.""" + if (type(ba) == bytearray and len(ba) >= 16 and (len(ba) & 7 == 0)): + header = ndarray(shape=[2], buffer=ba, dtype="int64"); + if (header[0] != 1): + raise TypeError("deserialize_double_vector called on bytearray with " \ + "wrong magic"); + length = header[1]; + if (len(ba) != 8*length + 16): + raise TypeError("deserialize_double_vector called on bytearray with " \ + "wrong length"); + return deserialize_byte_array([length], ba, 16); + else: + raise TypeError("deserialize_double_vector called on a non-bytearray"); + +def serialize_double_matrix(m): + """Implementation detail. Do not use directly.""" + if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2): + rows = m.shape[0]; + cols = m.shape[1]; + ba = bytearray(24 + 8 * rows * cols); + header = ndarray(shape=[3], buffer=ba, dtype="int64"); + header[0] = 2; + header[1] = rows; + header[2] = cols; + copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24, dtype="float64", \ + order='C'), m); + return ba; + else: + print type(m); + print m.dtype; + print m.ndim; + raise TypeError("serialize_double_matrix called on a non-double-matrix"); + +def deserialize_double_matrix(ba): + """Implementation detail. 
Do not use directly.""" + if (type(ba) == bytearray and len(ba) >= 24 and (len(ba) & 7 == 0)): + header = ndarray(shape=[3], buffer=ba, dtype="int64"); + if (header[0] != 2): + raise TypeError("deserialize_double_matrix called on bytearray with " \ + "wrong magic"); + rows = header[1]; + cols = header[2]; + if (len(ba) != 8*rows*cols + 24): + raise TypeError("deserialize_double_matrix called on bytearray with " \ + "wrong length"); + return deserialize_byte_array([rows, cols], ba, 24); + else: + raise TypeError("deserialize_double_matrix called on a non-bytearray"); + +class LinearRegressionModel: + _coeff = None; + _intercept = None; + def __init__(self, coeff, intercept): + self._coeff = coeff; + self._intercept = intercept; + def predict(self, x): + if (type(x) == ndarray): + if (x.ndim == 1): + return dot(_coeff, x) - _intercept; + else: + raise RuntimeError("Bulk predict not yet supported."); + elif (type(x) == RDD): + raise RuntimeError("Bulk predict not yet supported."); + else: + raise TypeError("Bad type argument to LinearRegressionModel::predict"); + +def train_linear_regression_model(sc, data): + """Train a linear regression model on the given data.""" + dataBytes = data.map(serialize_double_vector); + sc.serializer = NoOpSerializer(); + dataBytes.cache(); + api = sc._jvm.PythonMLLibAPI(); + ans = api.trainLinearRegressionModel(dataBytes._jrdd); + if (len(ans) != 2 or type(ans[0]) != bytearray or type(ans[1]) != float): + raise RuntimeError("train_linear_regression_model received garbage " \ + "from JVM"); + return LinearRegressionModel(deserialize_double_vector(ans[0]), ans[1]); -- cgit v1.2.3 From bf491bb3c0a9008caa4ac112672a4760b3d1c7b8 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Thu, 19 Dec 2013 01:29:51 -0500 Subject: The rest of the Python side of those bindings. 
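For illustration, a sketch of a round trip through the double-vector wire format these binding commits use between Python and the JVM, assuming the PythonMLLibAPI class added in the previous commit is on the driver's classpath; per that patch, the format is an 8-byte magic of 1, an 8-byte length, then the doubles, all in machine-endian order:

    val api = new PythonMLLibAPI()
    // Header: [magic = 1][length = 3], followed by 3 * 8 bytes of doubles (native byte order).
    val bytes = api.serializeDoubleVector(Array(1.0, 2.0, 3.0))
    val roundTripped = api.deserializeDoubleVector(bytes)
    assert(roundTripped.sameElements(Array(1.0, 2.0, 3.0)))
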
--- python/pyspark/__init__.py | 3 ++- python/pyspark/java_gateway.py | 1 + python/pyspark/serializers.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 1f35f6f939..949406c57b 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -42,6 +42,7 @@ from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel +from pyspark.mllib import train_linear_regression_model -__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel"] +__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", "train_linear_regression_model"] diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index e615c1e9b6..2941984e19 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -62,5 +62,6 @@ def launch_gateway(): # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.mllib.api.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 811fa6f018..2a500ab919 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -308,4 +308,4 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) - stream.write(obj) \ No newline at end of file + stream.write(obj) -- cgit v1.2.3 From ec71b445ad0440e84c4b4909e4faf75aba0f13d7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Dec 2013 23:39:28 -0800 Subject: Minor changes. --- .../org/apache/spark/scheduler/SparkListenerBus.scala | 1 - .../org/apache/spark/streaming/StreamingContext.scala | 10 ++++++---- .../spark/streaming/api/java/JavaStreamingContext.scala | 8 ++++++++ .../apache/spark/streaming/scheduler/BatchInfo.scala | 7 +++---- .../org/apache/spark/streaming/scheduler/Job.scala | 14 ++++---------- .../apache/spark/streaming/scheduler/JobGenerator.scala | 4 ++++ .../apache/spark/streaming/scheduler/JobScheduler.scala | 4 +++- .../spark/streaming/scheduler/StreamingListener.scala | 17 ++++++----------- .../streaming/scheduler/StreamingListenerBus.scala | 2 +- .../apache/spark/streaming/StreamingListenerSuite.scala | 2 +- 10 files changed, 36 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index d5824e7954..85687ea330 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging { return true } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index fedbbde80c..41da028a3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -513,7 +513,10 @@ class StreamingContext private ( graph.addOutputStream(outputStream) } - def addListener(streamingListener: StreamingListener) { + /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. 
+ */ + def addStreamingListener(streamingListener: StreamingListener) { scheduler.listenerBus.addListener(streamingListener) } @@ -532,20 +535,19 @@ class StreamingContext private ( * Start the execution of the streams. */ def start() { - validate() + // Get the network input streams val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true case _ => false }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray + // Start the network input tracker (must start before receivers) if (networkInputStreams.length > 0) { - // Start the network input tracker (must start before receivers) networkInputTracker = new NetworkInputTracker(this, networkInputStreams) networkInputTracker.start() } - Thread.sleep(1000) // Start the scheduler diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 80dcf87491..78d318cf27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -39,6 +39,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaRDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.scheduler.StreamingListener /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -687,6 +688,13 @@ class JavaStreamingContext(val ssc: StreamingContext) { ssc.remember(duration) } + /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. + */ + def addStreamingListener(streamingListener: StreamingListener) { + ssc.addStreamingListener(streamingListener) + } + /** * Starts the execution of the streams. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 798598ad50..88e4af59b7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -19,6 +19,9 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.streaming.Time +/** + * Class having information on completed batches. + */ case class BatchInfo( batchTime: Time, submissionTime: Long, @@ -32,7 +35,3 @@ case class BatchInfo( def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption } - - - - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index bca5e1f1a5..7341bfbc99 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.atomic.AtomicLong import org.apache.spark.streaming.Time +/** + * Class representing a Spark computation. It may contain multiple Spark jobs. 
+ */ private[streaming] class Job(val time: Time, func: () => _) { var id: String = _ @@ -36,12 +38,4 @@ class Job(val time: Time, func: () => _) { } override def toString = id -} -/* -private[streaming] -object Job { - val id = new AtomicLong(0) - - def getNewId() = id.getAndIncrement() -} -*/ +} \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 5d3ce9c398..1cd0b9b0a4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -22,6 +22,10 @@ import org.apache.spark.Logging import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} +/** + * This class generates jobs from DStreams as well as drives checkpointing and cleaning + * up DStream metadata. + */ private[streaming] class JobGenerator(jobScheduler: JobScheduler) extends Logging { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 69930f3b6c..33c5322358 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -23,7 +23,9 @@ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} import scala.collection.mutable.HashSet import org.apache.spark.streaming._ - +/** + * This class drives the generation of Spark jobs from the DStreams. + */ private[streaming] class JobScheduler(val ssc: StreamingContext) extends Logging { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 5647ffab8d..36225e190c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -50,19 +50,13 @@ trait StreamingListener { * @param numBatchInfos Number of last batches to consider for generating statistics (default: 10) */ class StatsReportListener(numBatchInfos: Int = 10) extends StreamingListener { - - import org.apache.spark - + // Queue containing latest completed batches val batchInfos = new Queue[BatchInfo]() override def onBatchCompleted(batchStarted: StreamingListenerBatchCompleted) { - addToQueue(batchStarted.batchInfo) - printStats() - } - - def addToQueue(newPoint: BatchInfo) { - batchInfos.enqueue(newPoint) + batchInfos.enqueue(batchStarted.batchInfo) if (batchInfos.size > numBatchInfos) batchInfos.dequeue() + printStats() } def printStats() { @@ -71,10 +65,11 @@ class StatsReportListener(numBatchInfos: Int = 10) extends StreamingListener { } def showMillisDistribution(heading: String, getMetric: BatchInfo => Option[Long]) { - spark.scheduler.StatsReportListener.showMillisDistribution(heading, extractDistribution(getMetric)) + org.apache.spark.scheduler.StatsReportListener.showMillisDistribution( + heading, extractDistribution(getMetric)) } def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { Distribution(batchInfos.flatMap(getMetric(_)).map(_.toDouble)) } -} \ No newline at end of file +} diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 324e491914..110a20f282 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -78,4 +78,4 @@ private[spark] class StreamingListenerBus() extends Logging { } return true } -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 826c839932..16410a21e3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -34,7 +34,7 @@ class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers{ test("basic BatchInfo generation") { val ssc = setupStreams(input, operation) val collector = new BatchInfoCollector - ssc.addListener(collector) + ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) val batchInfos = collector.batchInfos batchInfos should have size 4 -- cgit v1.2.3 From bf20591a006b9d2fdd9a674d637f5e929fd065a2 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Thu, 19 Dec 2013 03:40:57 -0500 Subject: Incorporate most of Josh's style suggestions. I don't want to deal with the type and length checking errors until we've got at least one working stub that we're all happy with. --- python/pyspark/__init__.py | 4 +- python/pyspark/mllib.py | 185 ++++++++++++++++++++++----------------------- 2 files changed, 91 insertions(+), 98 deletions(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 949406c57b..9f71db397d 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -42,7 +42,7 @@ from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel -from pyspark.mllib import train_linear_regression_model +from pyspark.mllib import LinearRegressionModel -__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", "train_linear_regression_model"] +__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", "LinearRegressionModel"]; diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 8237f66d67..0dfc4909c7 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -1,8 +1,4 @@ -from numpy import *; -from pyspark.serializers import NoOpSerializer, FramedSerializer, \ - BatchedSerializer, CloudPickleSerializer, pack_long - -#__all__ = ["train_linear_regression_model"]; +from numpy import * # Double vector format: # @@ -15,100 +11,97 @@ from pyspark.serializers import NoOpSerializer, FramedSerializer, \ # This is all in machine-endian. That means that the Java interpreter and the # Python interpreter must agree on what endian the machine is. -def deserialize_byte_array(shape, ba, offset): - """Implementation detail. Do not use directly.""" - ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", \ - order='C'); - return ar.copy(); - -def serialize_double_vector(v): - """Implementation detail. 
Do not use directly.""" - if (type(v) == ndarray and v.dtype == float64 and v.ndim == 1): - length = v.shape[0]; - ba = bytearray(16 + 8*length); - header = ndarray(shape=[2], buffer=ba, dtype="int64"); - header[0] = 1; - header[1] = length; - copyto(ndarray(shape=[length], buffer=ba, offset=16, dtype="float64"), v); - return ba; - else: - raise TypeError("serialize_double_vector called on a non-double-vector"); +def _deserialize_byte_array(shape, ba, offset): + ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", + order='C') + return ar.copy() -def deserialize_double_vector(ba): - """Implementation detail. Do not use directly.""" - if (type(ba) == bytearray and len(ba) >= 16 and (len(ba) & 7 == 0)): - header = ndarray(shape=[2], buffer=ba, dtype="int64"); - if (header[0] != 1): - raise TypeError("deserialize_double_vector called on bytearray with " \ - "wrong magic"); - length = header[1]; - if (len(ba) != 8*length + 16): - raise TypeError("deserialize_double_vector called on bytearray with " \ - "wrong length"); - return deserialize_byte_array([length], ba, 16); - else: - raise TypeError("deserialize_double_vector called on a non-bytearray"); +def _serialize_double_vector(v): + if (type(v) == ndarray and v.dtype == float64 and v.ndim == 1): + length = v.shape[0] + ba = bytearray(16 + 8*length) + header = ndarray(shape=[2], buffer=ba, dtype="int64") + header[0] = 1 + header[1] = length + copyto(ndarray(shape=[length], buffer=ba, offset=16, + dtype="float64"), v) + return ba + else: + raise TypeError("_serialize_double_vector called on a " + "non-double-vector") -def serialize_double_matrix(m): - """Implementation detail. Do not use directly.""" - if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2): - rows = m.shape[0]; - cols = m.shape[1]; - ba = bytearray(24 + 8 * rows * cols); - header = ndarray(shape=[3], buffer=ba, dtype="int64"); - header[0] = 2; - header[1] = rows; - header[2] = cols; - copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24, dtype="float64", \ - order='C'), m); - return ba; - else: - print type(m); - print m.dtype; - print m.ndim; - raise TypeError("serialize_double_matrix called on a non-double-matrix"); +def _deserialize_double_vector(ba): + if (type(ba) == bytearray and len(ba) >= 16 and (len(ba) & 7 == 0)): + header = ndarray(shape=[2], buffer=ba, dtype="int64") + if (header[0] != 1): + raise TypeError("_deserialize_double_vector called on bytearray " + "with wrong magic") + length = header[1] + if (len(ba) != 8*length + 16): + raise TypeError("_deserialize_double_vector called on bytearray " + "with wrong length") + return _deserialize_byte_array([length], ba, 16) + else: + raise TypeError("_deserialize_double_vector called on a non-bytearray") -def deserialize_double_matrix(ba): - """Implementation detail. 
Do not use directly.""" - if (type(ba) == bytearray and len(ba) >= 24 and (len(ba) & 7 == 0)): - header = ndarray(shape=[3], buffer=ba, dtype="int64"); - if (header[0] != 2): - raise TypeError("deserialize_double_matrix called on bytearray with " \ - "wrong magic"); - rows = header[1]; - cols = header[2]; - if (len(ba) != 8*rows*cols + 24): - raise TypeError("deserialize_double_matrix called on bytearray with " \ - "wrong length"); - return deserialize_byte_array([rows, cols], ba, 24); - else: - raise TypeError("deserialize_double_matrix called on a non-bytearray"); +def _serialize_double_matrix(m): + if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2): + rows = m.shape[0] + cols = m.shape[1] + ba = bytearray(24 + 8 * rows * cols) + header = ndarray(shape=[3], buffer=ba, dtype="int64") + header[0] = 2 + header[1] = rows + header[2] = cols + copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24, + dtype="float64", order='C'), m) + return ba + else: + raise TypeError("_serialize_double_matrix called on a " + "non-double-matrix") -class LinearRegressionModel: - _coeff = None; - _intercept = None; - def __init__(self, coeff, intercept): - self._coeff = coeff; - self._intercept = intercept; - def predict(self, x): - if (type(x) == ndarray): - if (x.ndim == 1): - return dot(_coeff, x) - _intercept; - else: - raise RuntimeError("Bulk predict not yet supported."); - elif (type(x) == RDD): - raise RuntimeError("Bulk predict not yet supported."); +def _deserialize_double_matrix(ba): + if (type(ba) == bytearray and len(ba) >= 24 and (len(ba) & 7 == 0)): + header = ndarray(shape=[3], buffer=ba, dtype="int64") + if (header[0] != 2): + raise TypeError("_deserialize_double_matrix called on bytearray " + "with wrong magic") + rows = header[1] + cols = header[2] + if (len(ba) != 8*rows*cols + 24): + raise TypeError("_deserialize_double_matrix called on bytearray " + "with wrong length") + return _deserialize_byte_array([rows, cols], ba, 24) else: - raise TypeError("Bad type argument to LinearRegressionModel::predict"); + raise TypeError("_deserialize_double_matrix called on a non-bytearray") + +class LinearRegressionModel(object): + def __init__(self, coeff, intercept): + self._coeff = coeff + self._intercept = intercept + + def predict(self, x): + if (type(x) == ndarray): + if (x.ndim == 1): + return dot(_coeff, x) - _intercept + else: + raise RuntimeError("Bulk predict not yet supported.") + elif (type(x) == RDD): + raise RuntimeError("Bulk predict not yet supported.") + else: + raise TypeError("Bad type argument to " + "LinearRegressionModel::predict") -def train_linear_regression_model(sc, data): - """Train a linear regression model on the given data.""" - dataBytes = data.map(serialize_double_vector); - sc.serializer = NoOpSerializer(); - dataBytes.cache(); - api = sc._jvm.PythonMLLibAPI(); - ans = api.trainLinearRegressionModel(dataBytes._jrdd); - if (len(ans) != 2 or type(ans[0]) != bytearray or type(ans[1]) != float): - raise RuntimeError("train_linear_regression_model received garbage " \ - "from JVM"); - return LinearRegressionModel(deserialize_double_vector(ans[0]), ans[1]); + @classmethod + def train(cls, sc, data): + """Train a linear regression model on the given data.""" + dataBytes = data.map(_serialize_double_vector) + dataBytes._bypass_serializer = True + dataBytes.cache() + api = sc._jvm.PythonMLLibAPI() + ans = api.trainLinearRegressionModel(dataBytes._jrdd) + if (len(ans) != 2 or type(ans[0]) != bytearray + or type(ans[1]) != float): + raise 
RuntimeError("train_linear_regression_model received " + "garbage from JVM") + return LinearRegressionModel(_deserialize_double_vector(ans[0]), ans[1]) -- cgit v1.2.3 From 0647ec97573dc267c7a6b4679fb938b4dfa4fbb6 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 19 Dec 2013 15:10:48 -0800 Subject: Clean up shuffle files once their metadata is gone Previously, we would only clean the in-memory metadata for consolidated shuffle files. Additionally, fixes a bug where the Metadata Cleaner was ignoring type- specific TTLs. --- .../apache/spark/storage/ShuffleBlockManager.scala | 25 +++++++++++++++++++--- .../org/apache/spark/util/MetadataCleaner.scala | 2 +- .../org/apache/spark/util/TimeStampedHashMap.scala | 15 ++++++++++--- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index e828e1d1c5..212ef6506f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -70,10 +70,16 @@ class ShuffleBlockManager(blockManager: BlockManager) { * Contains all the state related to a particular shuffle. This includes a pool of unused * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. */ - private class ShuffleState() { + private class ShuffleState(val numBuckets: Int) { val nextFileId = new AtomicInteger(0) val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() + + /** + * The mapIds of all map tasks completed on this Executor for this shuffle. + * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. 
+ */ + val completedMapTasks = new ConcurrentLinkedQueue[Int]() } type ShuffleId = Int @@ -84,7 +90,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState()) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) private var fileGroup: ShuffleFileGroup = null @@ -109,6 +115,8 @@ class ShuffleBlockManager(blockManager: BlockManager) { fileGroup.recordMapOutput(mapId, offsets) } recycleFileGroup(fileGroup) + } else { + shuffleState.completedMapTasks.add(mapId) } } @@ -154,7 +162,18 @@ class ShuffleBlockManager(blockManager: BlockManager) { } private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime) + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => { + if (consolidateShuffleFiles) { + for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + file.delete() + } + } else { + for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() + } + } + }) } } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 7b41ef89f1..fe56960cbf 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -27,7 +27,7 @@ import org.apache.spark.Logging class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, cleanupFunc: (Long) => Unit) extends Logging { val name = cleanerType.toString - private val delaySeconds = MetadataCleaner.getDelaySeconds + private val delaySeconds = MetadataCleaner.getDelaySeconds(cleanerType) private val periodSeconds = math.max(10, delaySeconds / 10) private val timer = new Timer(name + " cleanup timer", true) diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index dbff571de9..181ae2fd45 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -104,19 +104,28 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { def toMap: immutable.Map[A, B] = iterator.toMap /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` + * Removes old key-value pairs that have timestamp earlier than `threshTime`, + * calling the supplied function on each such entry before removing. 
*/ - def clearOldValues(threshTime: Long) { + def clearOldValues(threshTime: Long, f: (A, B) => Unit) { val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { + while (iterator.hasNext) { val entry = iterator.next() if (entry.getValue._2 < threshTime) { + f(entry.getKey, entry.getValue._1) logDebug("Removing key " + entry.getKey) iterator.remove() } } } + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + clearOldValues(threshTime, (_, _) => ()) + } + private def currentTime: Long = System.currentTimeMillis() } -- cgit v1.2.3 From 2a41c9aad3d0a8477a11bf910fa57b49ea4dc6dc Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Thu, 19 Dec 2013 21:27:11 -0500 Subject: Un-semicolon PythonMLLibAPI. --- .../apache/spark/mllib/api/PythonMLLibAPI.scala | 54 +++++++++++----------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index 19d2e9a773..3daf5dcb39 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -6,46 +6,46 @@ import java.nio.DoubleBuffer class PythonMLLibAPI extends Serializable { def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { - val packetLength = bytes.length; + val packetLength = bytes.length if (packetLength < 16) { - throw new IllegalArgumentException("Byte array too short."); + throw new IllegalArgumentException("Byte array too short.") } - val bb = ByteBuffer.wrap(bytes); - bb.order(ByteOrder.nativeOrder()); - val magic = bb.getLong(); + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.getLong() if (magic != 1) { - throw new IllegalArgumentException("Magic " + magic + " is wrong."); + throw new IllegalArgumentException("Magic " + magic + " is wrong.") } - val length = bb.getLong(); + val length = bb.getLong() if (packetLength != 16 + 8 * length) { - throw new IllegalArgumentException("Length " + length + "is wrong."); + throw new IllegalArgumentException("Length " + length + "is wrong.") } - val db = bb.asDoubleBuffer(); - val ans = new Array[Double](length.toInt); - db.get(ans); - return ans; + val db = bb.asDoubleBuffer() + val ans = new Array[Double](length.toInt) + db.get(ans) + return ans } def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { - val len = doubles.length; - val bytes = new Array[Byte](16 + 8 * len); - val bb = ByteBuffer.wrap(bytes); - bb.order(ByteOrder.nativeOrder()); - bb.putLong(1); - bb.putLong(len); - val db = bb.asDoubleBuffer(); - db.put(doubles); - return bytes; + val len = doubles.length + val bytes = new Array[Byte](16 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putLong(1) + bb.putLong(len) + val db = bb.asDoubleBuffer() + db.put(doubles) + return bytes } def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]): java.util.List[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x)) - .map(v => LabeledPoint(v(0), v.slice(1, v.length))); - val model = LinearRegressionWithSGD.train(data, 222); - val ret = new java.util.LinkedList[java.lang.Object](); - ret.add(serializeDoubleVector(model.weights)); - ret.add(model.intercept: java.lang.Double); - return ret; + .map(v => LabeledPoint(v(0), v.slice(1, v.length))) + val model = 
LinearRegressionWithSGD.train(data, 222) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleVector(model.weights)) + ret.add(model.intercept: java.lang.Double) + return ret } } -- cgit v1.2.3 From ded67ee90c2c0b22d67e623156a3f6cce8573abd Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Thu, 19 Dec 2013 22:42:12 -0500 Subject: Bindings for linear, Lasso, and ridge regression. --- .../apache/spark/mllib/api/PythonMLLibAPI.scala | 42 +++++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index 3daf5dcb39..c9bd7c6415 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -1,5 +1,6 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression._ +import org.apache.spark.rdd.RDD import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.DoubleBuffer @@ -38,14 +39,45 @@ class PythonMLLibAPI extends Serializable { return bytes } - def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]): - java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x)) - .map(v => LabeledPoint(v(0), v.slice(1, v.length))) - val model = LinearRegressionWithSGD.train(data, 222) + def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, + dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): + java.util.LinkedList[java.lang.Object] = { + val data = dataBytesJRDD.rdd.map(xBytes => { + val x = deserializeDoubleVector(xBytes) + LabeledPoint(x(0), x.slice(1, x.length)) + }) + val initialWeights = deserializeDoubleVector(initialWeightsBA) + val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() ret.add(serializeDoubleVector(model.weights)) ret.add(model.intercept: java.lang.Double) return ret } + + def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, stepSize: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LinearRegressionWithSGD.train(data, numIterations, stepSize, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA); + } + + def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LassoWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA); + } + + def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA); + } } -- cgit v1.2.3 From 2328bdd00f701ca3b1bc7fdf8b2968fafc58fd11 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Thu, 19 Dec 2013 22:45:16 -0500 Subject: Python side of python bindings for linear, Lasso, and ridge regression --- 
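(Illustrative aside, not part of this patch: the double-vector wire format these bindings rely on, an 8-byte magic of 1, an 8-byte element count, then the float64 payload, all in the machine's native byte order, can be exercised outside Spark with a small numpy sketch. The helper names pack_double_vector and unpack_double_vector below are hypothetical and exist only for this example.)

    import numpy as np

    def pack_double_vector(v):
        # Hypothetical packer: [8-byte 1][8-byte length][length*8 bytes of float64],
        # everything in the machine's native byte order.
        length = v.shape[0]
        ba = bytearray(16 + 8 * length)
        header = np.ndarray(shape=[2], buffer=ba, dtype="int64")
        header[0] = 1
        header[1] = length
        np.copyto(np.ndarray(shape=[length], buffer=ba, offset=16, dtype="float64"), v)
        return ba

    def unpack_double_vector(ba):
        # Reverse of pack_double_vector; assumes the same native byte order.
        header = np.ndarray(shape=[2], buffer=ba, dtype="int64")
        assert header[0] == 1, "wrong magic"
        length = int(header[1])
        return np.ndarray(shape=[length], buffer=ba, offset=16, dtype="float64").copy()

    v = np.array([0.0, 1.5, -2.25])
    assert np.array_equal(v, unpack_double_vector(pack_double_vector(v)))

(Because the format is native-endian, the Java side and the Python side must agree on the machine's endianness, which is exactly the assumption the comment in mllib.py calls out.)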
python/pyspark/__init__.py | 6 ++-- python/pyspark/mllib.py | 81 ++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 72 insertions(+), 15 deletions(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9f71db397d..7c8f9148d5 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -42,7 +42,9 @@ from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel -from pyspark.mllib import LinearRegressionModel +from pyspark.mllib import LinearRegressionModel, LassoModel, \ + RidgeRegressionModel -__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", "LinearRegressionModel"]; +__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", \ + "LinearRegressionModel", "LassoModel", "RidgeRegressionModel"]; diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 0dfc4909c7..d3127874be 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -75,7 +75,7 @@ def _deserialize_double_matrix(ba): else: raise TypeError("_deserialize_double_matrix called on a non-bytearray") -class LinearRegressionModel(object): +class LinearModel(object): def __init__(self, coeff, intercept): self._coeff = coeff self._intercept = intercept @@ -83,7 +83,7 @@ class LinearRegressionModel(object): def predict(self, x): if (type(x) == ndarray): if (x.ndim == 1): - return dot(_coeff, x) - _intercept + return dot(_coeff, x) + _intercept else: raise RuntimeError("Bulk predict not yet supported.") elif (type(x) == RDD): @@ -92,16 +92,71 @@ class LinearRegressionModel(object): raise TypeError("Bad type argument to " "LinearRegressionModel::predict") +# Map a pickled Python RDD of numpy double vectors to a Java RDD of +# _serialized_double_vectors +def _get_unmangled_double_vector_rdd(data): + dataBytes = data.map(_serialize_double_vector) + dataBytes._bypass_serializer = True + dataBytes.cache() + return dataBytes; + +# If we weren't given initial weights, take a zero vector of the appropriate +# length. +def _get_initial_weights(initial_weights, data): + if initial_weights is None: + initial_weights = data.first() + if type(initial_weights) != ndarray: + raise TypeError("At least one data element has type " + + type(initial_weights) + " which is not ndarray") + if initial_weights.ndim != 1: + raise TypeError("At least one data element has " + + initial_weights.ndim + " dimensions, which is not 1") + initial_weights = zeros([initial_weights.shape[0] - 1]); + return initial_weights; + +# train_func should take two parameters, namely data and initial_weights, and +# return the result of a call to the appropriate JVM stub. +# _regression_train_wrapper is responsible for setup and error checking. 
+def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): + initial_weights = _get_initial_weights(initial_weights, data) + dataBytes = _get_unmangled_double_vector_rdd(data) + ans = train_func(dataBytes, _serialize_double_vector(initial_weights)) + if len(ans) != 2: + raise RuntimeError("JVM call result had unexpected length"); + elif type(ans[0]) != bytearray: + raise RuntimeError("JVM call result had first element of type " + + type(ans[0]) + " which is not bytearray"); + elif type(ans[1]) != float: + raise RuntimeError("JVM call result had second element of type " + + type(ans[0]) + " which is not float"); + return klass(_deserialize_double_vector(ans[0]), ans[1]); + +class LinearRegressionModel(LinearModel): @classmethod - def train(cls, sc, data): + def train(cls, sc, data, iterations=100, step=1.0, + mini_batch_fraction=1.0, initial_weights=None): """Train a linear regression model on the given data.""" - dataBytes = data.map(_serialize_double_vector) - dataBytes._bypass_serializer = True - dataBytes.cache() - api = sc._jvm.PythonMLLibAPI() - ans = api.trainLinearRegressionModel(dataBytes._jrdd) - if (len(ans) != 2 or type(ans[0]) != bytearray - or type(ans[1]) != float): - raise RuntimeError("train_linear_regression_model received " - "garbage from JVM") - return LinearRegressionModel(_deserialize_double_vector(ans[0]), ans[1]) + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLinearRegressionModel( + d._jrdd, iterations, step, mini_batch_fraction, i), + LinearRegressionModel, data, initial_weights) + +class LassoModel(LinearModel): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a Lasso regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLassoModel(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + LassoModel, data, initial_weights) + +class RidgeRegressionModel(LinearModel): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a ridge regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainRidgeModel(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + RidgeRegressionModel, data, initial_weights) -- cgit v1.2.3 From f99970e8cdc85eae33999b57a4c5c1893fe3727a Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 00:12:22 -0500 Subject: Scala classification and clustering stubs; matrix serialization/deserialization. 
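(Illustrative aside, not part of this commit: the matrix format handled below mirrors the vector format, an 8-byte magic of 2, 8-byte row and column counts, then rows*cols float64 values in row-major order and native byte order. A minimal round-trip sketch, with hypothetical helper names used only for this example:)

    import numpy as np

    def pack_double_matrix(m):
        # Hypothetical packer: [8-byte 2][8-byte rows][8-byte cols][row-major float64 data].
        rows, cols = m.shape
        ba = bytearray(24 + 8 * rows * cols)
        header = np.ndarray(shape=[3], buffer=ba, dtype="int64")
        header[0], header[1], header[2] = 2, rows, cols
        np.copyto(np.ndarray(shape=[rows, cols], buffer=ba, offset=24,
                             dtype="float64", order='C'), m)
        return ba

    def unpack_double_matrix(ba):
        # Reverse of pack_double_matrix; assumes the same native byte order.
        header = np.ndarray(shape=[3], buffer=ba, dtype="int64")
        assert header[0] == 2, "wrong magic"
        rows, cols = int(header[1]), int(header[2])
        return np.ndarray(shape=[rows, cols], buffer=ba, offset=24,
                          dtype="float64", order='C').copy()

    m = np.arange(6, dtype=np.float64).reshape(2, 3)
    assert np.array_equal(m, unpack_double_matrix(pack_double_matrix(m)))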
--- .../apache/spark/mllib/api/PythonMLLibAPI.scala | 82 +++++++++++++++++++++- 1 file changed, 79 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index c9bd7c6415..bcf2f07517 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -1,5 +1,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.classification._ +import org.apache.spark.mllib.clustering._ import org.apache.spark.rdd.RDD import java.nio.ByteBuffer import java.nio.ByteOrder @@ -39,6 +41,52 @@ class PythonMLLibAPI extends Serializable { return bytes } + def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { + val packetLength = bytes.length + if (packetLength < 24) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.getLong() + if (magic != 2) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val rows = bb.getLong() + val cols = bb.getLong() + if (packetLength != 24 + 8 * rows * cols) { + throw new IllegalArgumentException("Size " + rows + "x" + cols + "is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Array[Double]](rows.toInt) + var i = 0 + for (i <- 0 until rows.toInt) { + ans(i) = new Array[Double](cols.toInt) + db.get(ans(i)) + } + return ans + } + + def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { + val rows = doubles.length + var cols = 0 + if (rows > 0) { + cols = doubles(0).length + } + val bytes = new Array[Byte](24 + 8 * rows * cols) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putLong(2) + bb.putLong(rows) + bb.putLong(cols) + val db = bb.asDoubleBuffer() + var i = 0 + for (i <- 0 until rows) { + db.put(doubles(i)) + } + return bytes + } + def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { @@ -60,7 +108,7 @@ class PythonMLLibAPI extends Serializable { return trainRegressionModel((data, initialWeights) => LinearRegressionWithSGD.train(data, numIterations, stepSize, miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA); + dataBytesJRDD, initialWeightsBA) } def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, @@ -69,7 +117,7 @@ class PythonMLLibAPI extends Serializable { return trainRegressionModel((data, initialWeights) => LassoWithSGD.train(data, numIterations, stepSize, regParam, miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA); + dataBytesJRDD, initialWeightsBA) } def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, @@ -78,6 +126,34 @@ class PythonMLLibAPI extends Serializable { return trainRegressionModel((data, initialWeights) => RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA); + dataBytesJRDD, initialWeightsBA) + } + + def trainSVMModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, 
initialWeights) => + SVMWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + def trainLogisticRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, stepSize: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LogisticRegressionWithSGD.train(data, numIterations, stepSize, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, + maxIterations: Int, runs: Int, initializationMode: String): + java.util.List[java.lang.Object] = { + val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + val model = KMeans.train(data, k, maxIterations, runs, initializationMode) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleMatrix(model.clusterCenters)) + return ret } } -- cgit v1.2.3 From 73e17064c60c5aa2297dffbeaae4747890da0115 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 00:12:48 -0500 Subject: Python stubs for classification and clustering. --- python/pyspark/__init__.py | 7 +-- python/pyspark/mllib.py | 105 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 96 insertions(+), 16 deletions(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 7c8f9148d5..8b5bb79a18 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -43,8 +43,9 @@ from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel from pyspark.mllib import LinearRegressionModel, LassoModel, \ - RidgeRegressionModel + RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel -__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", \ - "LinearRegressionModel", "LassoModel", "RidgeRegressionModel"]; +__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", + "LinearRegressionModel", "LassoModel", "RidgeRegressionModel", + "LogisticRegressionModel", "SVMModel", "KMeansModel"]; diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index d3127874be..21f3c0312c 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -75,22 +75,35 @@ def _deserialize_double_matrix(ba): else: raise TypeError("_deserialize_double_matrix called on a non-bytearray") +def _linear_predictor_typecheck(x, coeffs): + """Predict the class of the vector x.""" + if type(x) == ndarray: + if x.ndim == 1: + if x.shape == coeffs.shape: + pass + else: + raise RuntimeError("Got array of %d elements; wanted %d" + % shape(x)[0] % shape(coeffs)[0]) + else: + raise RuntimeError("Bulk predict not yet supported.") + elif (type(x) == RDD): + raise RuntimeError("Bulk predict not yet supported.") + else: + raise TypeError("Argument of type " + type(x) + " unsupported"); + class LinearModel(object): + """Something containing a vector of coefficients and an intercept.""" def __init__(self, coeff, intercept): self._coeff = coeff self._intercept = intercept +class LinearRegressionModelBase(LinearModel): + """A linear regression model.""" def predict(self, x): - if (type(x) == ndarray): - if (x.ndim == 1): - return dot(_coeff, x) + _intercept - else: - raise RuntimeError("Bulk predict not yet supported.") - elif (type(x) == RDD): - raise RuntimeError("Bulk predict not yet supported.") - else: - raise TypeError("Bad type argument to " - "LinearRegressionModel::predict") 
+ """Predict the value of the dependent variable given a vector x""" + """containing values for the independent variables.""" + _linear_predictor_typecheck(x, _coeff) + return dot(_coeff, x) + _intercept # Map a pickled Python RDD of numpy double vectors to a Java RDD of # _serialized_double_vectors @@ -131,7 +144,8 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): + type(ans[0]) + " which is not float"); return klass(_deserialize_double_vector(ans[0]), ans[1]); -class LinearRegressionModel(LinearModel): +class LinearRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit.""" @classmethod def train(cls, sc, data, iterations=100, step=1.0, mini_batch_fraction=1.0, initial_weights=None): @@ -141,7 +155,9 @@ class LinearRegressionModel(LinearModel): d._jrdd, iterations, step, mini_batch_fraction, i), LinearRegressionModel, data, initial_weights) -class LassoModel(LinearModel): +class LassoModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an """ + """l_1 penalty term.""" @classmethod def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, mini_batch_fraction=1.0, initial_weights=None): @@ -151,7 +167,9 @@ class LassoModel(LinearModel): iterations, step, reg_param, mini_batch_fraction, i), LassoModel, data, initial_weights) -class RidgeRegressionModel(LinearModel): +class RidgeRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an """ + """l_2 penalty term.""" @classmethod def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, mini_batch_fraction=1.0, initial_weights=None): @@ -160,3 +178,64 @@ class RidgeRegressionModel(LinearModel): sc._jvm.PythonMLLibAPI().trainRidgeModel(d._jrdd, iterations, step, reg_param, mini_batch_fraction, i), RidgeRegressionModel, data, initial_weights) + +class LogisticRegressionModel(LinearModel): + """A linear binary classification model derived from logistic regression.""" + def predict(self, x): + _linear_predictor_typecheck(x, _coeff) + margin = dot(x, _coeff) + intercept + prob = 1/(1 + exp(-margin)) + return 1 if prob > 0.5 else 0 + + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a logistic regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLogisticRegressionModel(d._jrdd, + iterations, step, mini_batch_fraction, i), + LogisticRegressionModel, data, initial_weights) + +class SVMModel(LinearModel): + """A support vector machine.""" + def predict(self, x): + _linear_predictor_typecheck(x, _coeff) + margin = dot(x, _coeff) + intercept + return 1 if margin >= 0 else 0 + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a support vector machine on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainSVMModel(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + SVMModel, data, initial_weights) + +class KMeansModel(object): + """A clustering model derived from the k-means method.""" + def __init__(self, centers_): + self.centers = centers_ + + def predict(self, x): + best = 0 + best_distance = 1e75 + for i in range(0, centers.shape[0]): + diff = x - centers[i] + distance = sqrt(dot(diff, diff)) + if distance < best_distance: + best = i + 
best_distance = distance + return best + + @classmethod + def train(cls, sc, data, k, maxIterations = 100, runs = 1, + initialization_mode="k-means||"): + dataBytes = _get_unmangled_double_vector_rdd(data) + ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd, + k, maxIterations, runs, initialization_mode) + if len(ans) != 1: + raise RuntimeError("JVM call result had unexpected length"); + elif type(ans[0]) != bytearray: + raise RuntimeError("JVM call result had first element of type " + + type(ans[0]) + " which is not bytearray"); + return KMeansModel(_deserialize_double_matrix(ans[0])); -- cgit v1.2.3 From 2940201ad86e5dee16cf7386b3c934fc75c15582 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 01:33:32 -0500 Subject: Tests for the Python side of the mllib bindings. --- python/pyspark/mllib.py | 224 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 172 insertions(+), 52 deletions(-) diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 21f3c0312c..aa9fc76c29 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -1,4 +1,5 @@ from numpy import * +from pyspark import SparkContext # Double vector format: # @@ -7,44 +8,106 @@ from numpy import * # Double matrix format: # # [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data] -# +# # This is all in machine-endian. That means that the Java interpreter and the # Python interpreter must agree on what endian the machine is. def _deserialize_byte_array(shape, ba, offset): + """Wrapper around ndarray aliasing hack. + + >>> x = array([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0)) + True + >>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2) + >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0)) + True + """ ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", order='C') return ar.copy() def _serialize_double_vector(v): - if (type(v) == ndarray and v.dtype == float64 and v.ndim == 1): - length = v.shape[0] - ba = bytearray(16 + 8*length) - header = ndarray(shape=[2], buffer=ba, dtype="int64") - header[0] = 1 - header[1] = length - copyto(ndarray(shape=[length], buffer=ba, offset=16, - dtype="float64"), v) - return ba - else: - raise TypeError("_serialize_double_vector called on a " - "non-double-vector") + """Serialize a double vector into a mutually understood format. 
+ + >>> _serialize_double_vector(array([])) + bytearray(b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00') + >>> _serialize_double_vector(array([0.0, 1.0])) + bytearray(b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\xf0?') + >>> _serialize_double_vector("hello, world") + Traceback (most recent call last): + File "/usr/lib/python2.7/doctest.py", line 1289, in __run + compileflags, 1) in test.globs + File "", line 1, in + _serialize_double_vector("hello, world") + File "python/pyspark/mllib.py", line 41, in _serialize_double_vector + raise TypeError("_serialize_double_vector called on a %s; wanted ndarray" % type(v)) + TypeError: _serialize_double_vector called on a ; wanted ndarray + >>> _serialize_double_vector(array([0, 1])) + Traceback (most recent call last): + File "/usr/lib/python2.7/doctest.py", line 1289, in __run + compileflags, 1) in test.globs + File "", line 1, in + _serialize_double_vector(array([0, 1])) + File "python/pyspark/mllib.py", line 51, in _serialize_double_vector + "wanted ndarray of float64" % v.dtype) + TypeError: _serialize_double_vector called on an ndarray of int64; wanted ndarray of float64 + >>> _serialize_double_vector(array([0.0, 1.0, 2.0, 3.0]).reshape(2,2)) + Traceback (most recent call last): + File "/usr/lib/python2.7/doctest.py", line 1289, in __run + compileflags, 1) in test.globs + File "", line 1, in + _serialize_double_vector(array([0.0, 1.0, 2.0, 3.0]).reshape(2,2)) + File "python/pyspark/mllib.py", line 62, in _serialize_double_vector + "wanted a 1darray" % v.ndim) + TypeError: _serialize_double_vector called on a 2darray; wanted a 1darray + """ + if type(v) != ndarray: + raise TypeError("_serialize_double_vector called on a %s; " + "wanted ndarray" % type(v)) + if v.dtype != float64: + raise TypeError("_serialize_double_vector called on an ndarray of %s; " + "wanted ndarray of float64" % v.dtype) + if v.ndim != 1: + raise TypeError("_serialize_double_vector called on a %ddarray; " + "wanted a 1darray" % v.ndim) + length = v.shape[0] + ba = bytearray(16 + 8*length) + header = ndarray(shape=[2], buffer=ba, dtype="int64") + header[0] = 1 + header[1] = length + copyto(ndarray(shape=[length], buffer=ba, offset=16, + dtype="float64"), v) + return ba def _deserialize_double_vector(ba): - if (type(ba) == bytearray and len(ba) >= 16 and (len(ba) & 7 == 0)): - header = ndarray(shape=[2], buffer=ba, dtype="int64") - if (header[0] != 1): - raise TypeError("_deserialize_double_vector called on bytearray " - "with wrong magic") - length = header[1] - if (len(ba) != 8*length + 16): - raise TypeError("_deserialize_double_vector called on bytearray " - "with wrong length") - return _deserialize_byte_array([length], ba, 16) - else: - raise TypeError("_deserialize_double_vector called on a non-bytearray") + """Deserialize a double vector from a mutually understood format. 
+ + >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]) + >>> array_equal(x, _deserialize_double_vector(_serialize_double_vector(x))) + True + """ + if type(ba) != bytearray: + raise TypeError("_deserialize_double_vector called on a %s; " + "wanted bytearray" % type(ba)) + if len(ba) < 16: + raise TypeError("_deserialize_double_vector called on a %d-byte array, " + "which is too short" % len(ba)) + if (len(ba) & 7) != 0: + raise TypeError("_deserialize_double_vector called on a %d-byte array, " + "which is not a multiple of 8" % len(ba)) + header = ndarray(shape=[2], buffer=ba, dtype="int64") + if header[0] != 1: + raise TypeError("_deserialize_double_vector called on bytearray " + "with wrong magic") + length = header[1] + if len(ba) != 8*length + 16: + raise TypeError("_deserialize_double_vector called on bytearray " + "with wrong length") + return _deserialize_byte_array([length], ba, 16) def _serialize_double_matrix(m): + """Serialize a double matrix into a mutually understood format. + """ if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2): rows = m.shape[0] cols = m.shape[1] @@ -61,22 +124,31 @@ def _serialize_double_matrix(m): "non-double-matrix") def _deserialize_double_matrix(ba): - if (type(ba) == bytearray and len(ba) >= 24 and (len(ba) & 7 == 0)): - header = ndarray(shape=[3], buffer=ba, dtype="int64") - if (header[0] != 2): - raise TypeError("_deserialize_double_matrix called on bytearray " - "with wrong magic") - rows = header[1] - cols = header[2] - if (len(ba) != 8*rows*cols + 24): - raise TypeError("_deserialize_double_matrix called on bytearray " - "with wrong length") - return _deserialize_byte_array([rows, cols], ba, 24) - else: - raise TypeError("_deserialize_double_matrix called on a non-bytearray") + """Deserialize a double matrix from a mutually understood format. + """ + if type(ba) != bytearray: + raise TypeError("_deserialize_double_matrix called on a %s; " + "wanted bytearray" % type(ba)) + if len(ba) < 24: + raise TypeError("_deserialize_double_matrix called on a %d-byte array, " + "which is too short" % len(ba)) + if (len(ba) & 7) != 0: + raise TypeError("_deserialize_double_matrix called on a %d-byte array, " + "which is not a multiple of 8" % len(ba)) + header = ndarray(shape=[3], buffer=ba, dtype="int64") + if (header[0] != 2): + raise TypeError("_deserialize_double_matrix called on bytearray " + "with wrong magic") + rows = header[1] + cols = header[2] + if (len(ba) != 8*rows*cols + 24): + raise TypeError("_deserialize_double_matrix called on bytearray " + "with wrong length") + return _deserialize_byte_array([rows, cols], ba, 24) def _linear_predictor_typecheck(x, coeffs): - """Predict the class of the vector x.""" + """Check that x is a one-dimensional vector of the right shape. + This is a temporary hackaround until I actually implement bulk predict.""" if type(x) == ndarray: if x.ndim == 1: if x.shape == coeffs.shape: @@ -98,12 +170,17 @@ class LinearModel(object): self._intercept = intercept class LinearRegressionModelBase(LinearModel): - """A linear regression model.""" + """A linear regression model. 
+ + >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) + >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6 + True + """ def predict(self, x): """Predict the value of the dependent variable given a vector x""" """containing values for the independent variables.""" - _linear_predictor_typecheck(x, _coeff) - return dot(_coeff, x) + _intercept + _linear_predictor_typecheck(x, self._coeff) + return dot(self._coeff, x) + self._intercept # Map a pickled Python RDD of numpy double vectors to a Java RDD of # _serialized_double_vectors @@ -145,7 +222,11 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): return klass(_deserialize_double_vector(ans[0]), ans[1]); class LinearRegressionModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit.""" + """A linear regression model derived from a least-squares fit. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = LinearRegressionModel.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ @classmethod def train(cls, sc, data, iterations=100, step=1.0, mini_batch_fraction=1.0, initial_weights=None): @@ -156,8 +237,12 @@ class LinearRegressionModel(LinearRegressionModelBase): LinearRegressionModel, data, initial_weights) class LassoModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an """ - """l_1 penalty term.""" + """A linear regression model derived from a least-squares fit with an + l_1 penalty term. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = LassoModel.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ @classmethod def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, mini_batch_fraction=1.0, initial_weights=None): @@ -168,8 +253,12 @@ class LassoModel(LinearRegressionModelBase): LassoModel, data, initial_weights) class RidgeRegressionModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an """ - """l_2 penalty term.""" + """A linear regression model derived from a least-squares fit with an + l_2 penalty term. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = RidgeRegressionModel.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ @classmethod def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, mini_batch_fraction=1.0, initial_weights=None): @@ -180,7 +269,11 @@ class RidgeRegressionModel(LinearRegressionModelBase): RidgeRegressionModel, data, initial_weights) class LogisticRegressionModel(LinearModel): - """A linear binary classification model derived from logistic regression.""" + """A linear binary classification model derived from logistic regression. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2) + >>> lrm = LogisticRegressionModel.train(sc, sc.parallelize(data)) + """ def predict(self, x): _linear_predictor_typecheck(x, _coeff) margin = dot(x, _coeff) + intercept @@ -197,7 +290,11 @@ class LogisticRegressionModel(LinearModel): LogisticRegressionModel, data, initial_weights) class SVMModel(LinearModel): - """A support vector machine.""" + """A support vector machine. 
+ + >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2) + >>> svm = SVMModel.train(sc, sc.parallelize(data)) + """ def predict(self, x): _linear_predictor_typecheck(x, _coeff) margin = dot(x, _coeff) + intercept @@ -212,15 +309,24 @@ class SVMModel(LinearModel): SVMModel, data, initial_weights) class KMeansModel(object): - """A clustering model derived from the k-means method.""" + """A clustering model derived from the k-means method. + + >>> data = array([0.0, 0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) + >>> clusters = KMeansModel.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random") + >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0])) + True + >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0])) + True + >>> clusters = KMeansModel.train(sc, sc.parallelize(data), 2) + """ def __init__(self, centers_): self.centers = centers_ def predict(self, x): best = 0 best_distance = 1e75 - for i in range(0, centers.shape[0]): - diff = x - centers[i] + for i in range(0, self.centers.shape[0]): + diff = x - self.centers[i] distance = sqrt(dot(diff, diff)) if distance < best_distance: best = i @@ -239,3 +345,17 @@ class KMeansModel(object): raise RuntimeError("JVM call result had first element of type " + type(ans[0]) + " which is not bytearray"); return KMeansModel(_deserialize_double_matrix(ans[0])); + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + print failure_count,"failures among",test_count,"tests" + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() -- cgit v1.2.3 From 319520b9bb0071527a0be1e0e545ca084ac090ee Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 01:48:44 -0500 Subject: Remove gigantic endian-specific test and exception tests. --- python/pyspark/mllib.py | 41 +++-------------------------------------- 1 file changed, 3 insertions(+), 38 deletions(-) diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index aa9fc76c29..e7e22166b0 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -27,40 +27,7 @@ def _deserialize_byte_array(shape, ba, offset): return ar.copy() def _serialize_double_vector(v): - """Serialize a double vector into a mutually understood format. 
- - >>> _serialize_double_vector(array([])) - bytearray(b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00') - >>> _serialize_double_vector(array([0.0, 1.0])) - bytearray(b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\xf0?') - >>> _serialize_double_vector("hello, world") - Traceback (most recent call last): - File "/usr/lib/python2.7/doctest.py", line 1289, in __run - compileflags, 1) in test.globs - File "", line 1, in - _serialize_double_vector("hello, world") - File "python/pyspark/mllib.py", line 41, in _serialize_double_vector - raise TypeError("_serialize_double_vector called on a %s; wanted ndarray" % type(v)) - TypeError: _serialize_double_vector called on a ; wanted ndarray - >>> _serialize_double_vector(array([0, 1])) - Traceback (most recent call last): - File "/usr/lib/python2.7/doctest.py", line 1289, in __run - compileflags, 1) in test.globs - File "", line 1, in - _serialize_double_vector(array([0, 1])) - File "python/pyspark/mllib.py", line 51, in _serialize_double_vector - "wanted ndarray of float64" % v.dtype) - TypeError: _serialize_double_vector called on an ndarray of int64; wanted ndarray of float64 - >>> _serialize_double_vector(array([0.0, 1.0, 2.0, 3.0]).reshape(2,2)) - Traceback (most recent call last): - File "/usr/lib/python2.7/doctest.py", line 1289, in __run - compileflags, 1) in test.globs - File "", line 1, in - _serialize_double_vector(array([0.0, 1.0, 2.0, 3.0]).reshape(2,2)) - File "python/pyspark/mllib.py", line 62, in _serialize_double_vector - "wanted a 1darray" % v.ndim) - TypeError: _serialize_double_vector called on a 2darray; wanted a 1darray - """ + """Serialize a double vector into a mutually understood format.""" if type(v) != ndarray: raise TypeError("_serialize_double_vector called on a %s; " "wanted ndarray" % type(v)) @@ -106,8 +73,7 @@ def _deserialize_double_vector(ba): return _deserialize_byte_array([length], ba, 16) def _serialize_double_matrix(m): - """Serialize a double matrix into a mutually understood format. - """ + """Serialize a double matrix into a mutually understood format.""" if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2): rows = m.shape[0] cols = m.shape[1] @@ -124,8 +90,7 @@ def _serialize_double_matrix(m): "non-double-matrix") def _deserialize_double_matrix(ba): - """Deserialize a double matrix from a mutually understood format. - """ + """Deserialize a double matrix from a mutually understood format.""" if type(ba) != bytearray: raise TypeError("_deserialize_double_matrix called on a %s; " "wanted bytearray" % type(ba)) -- cgit v1.2.3 From d89cc1e28a88a5f943fed096c4bd647f79753c9f Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 01:50:42 -0500 Subject: Whitespace. --- python/pyspark/mllib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index e7e22166b0..1877404cc1 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -276,7 +276,7 @@ class SVMModel(LinearModel): class KMeansModel(object): """A clustering model derived from the k-means method. 
- >>> data = array([0.0, 0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) + >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) >>> clusters = KMeansModel.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random") >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0])) True -- cgit v1.2.3 From b835ddf3dffe8698dab3b42c14a9da472868b13c Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 01:55:03 -0500 Subject: Licence notice. --- .../org/apache/spark/mllib/api/PythonMLLibAPI.scala | 17 +++++++++++++++++ python/pyspark/mllib.py | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index bcf2f07517..bad1f66424 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.classification._ diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 1877404cc1..ce1363fd17 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from numpy import * from pyspark import SparkContext -- cgit v1.2.3 From 0a5cacb9615d960c93bca8cc3f4ad2a599f94ec0 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 02:05:15 -0500 Subject: Change some docstrings and add some others. 
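(Illustrative aside, not part of this commit: the loop in KMeansModel.predict documented here, "find the cluster to which x belongs in this model", is a nearest-centre search over the rows of the deserialized centre matrix. A vectorised numpy sketch, with a hypothetical helper name used only for this example:)

    import numpy as np

    def nearest_center(x, centers):
        # `centers` holds one cluster centre per row; return the index of the
        # row closest to x in Euclidean distance, as predict() does above.
        diffs = centers - x
        return int(np.argmin((diffs * diffs).sum(axis=1)))

    centers = np.array([[0.0, 0.0], [9.0, 9.0]])
    assert nearest_center(np.array([1.0, 1.0]), centers) == 0
    assert nearest_center(np.array([8.0, 9.0]), centers) == 1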
--- python/pyspark/mllib.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index ce1363fd17..928caa9e80 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -146,7 +146,7 @@ def _linear_predictor_typecheck(x, coeffs): raise TypeError("Argument of type " + type(x) + " unsupported"); class LinearModel(object): - """Something containing a vector of coefficients and an intercept.""" + """Something that has a vector of coefficients and an intercept.""" def __init__(self, coeff, intercept): self._coeff = coeff self._intercept = intercept @@ -305,6 +305,7 @@ class KMeansModel(object): self.centers = centers_ def predict(self, x): + """Find the cluster to which x belongs in this model.""" best = 0 best_distance = 1e75 for i in range(0, self.centers.shape[0]): @@ -318,6 +319,7 @@ class KMeansModel(object): @classmethod def train(cls, sc, data, k, maxIterations = 100, runs = 1, initialization_mode="k-means||"): + """Train a k-means clustering model.""" dataBytes = _get_unmangled_double_vector_rdd(data) ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd, k, maxIterations, runs, initialization_mode) -- cgit v1.2.3 From 0b494c21675b6cc3b5d669dbd9b9a8f277216613 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 02:05:55 -0500 Subject: Un-semicolon mllib.py. --- python/pyspark/mllib.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 928caa9e80..8848284a5e 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -143,7 +143,7 @@ def _linear_predictor_typecheck(x, coeffs): elif (type(x) == RDD): raise RuntimeError("Bulk predict not yet supported.") else: - raise TypeError("Argument of type " + type(x) + " unsupported"); + raise TypeError("Argument of type " + type(x) + " unsupported") class LinearModel(object): """Something that has a vector of coefficients and an intercept.""" @@ -170,7 +170,7 @@ def _get_unmangled_double_vector_rdd(data): dataBytes = data.map(_serialize_double_vector) dataBytes._bypass_serializer = True dataBytes.cache() - return dataBytes; + return dataBytes # If we weren't given initial weights, take a zero vector of the appropriate # length. @@ -183,8 +183,8 @@ def _get_initial_weights(initial_weights, data): if initial_weights.ndim != 1: raise TypeError("At least one data element has " + initial_weights.ndim + " dimensions, which is not 1") - initial_weights = zeros([initial_weights.shape[0] - 1]); - return initial_weights; + initial_weights = zeros([initial_weights.shape[0] - 1]) + return initial_weights # train_func should take two parameters, namely data and initial_weights, and # return the result of a call to the appropriate JVM stub. 
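The comment above pins down the contract between the Python wrappers and the JVM stubs: train_func hands back whatever the stub returns, and _regression_train_wrapper (next hunk) then insists on a two-element result of [serialized weights, intercept]. Below is a rough sketch of the value a conforming stub would build on the Scala side; packageResult and its parameter names are illustrative, not part of these patches.

    // Illustrative only: the two-element shape _regression_train_wrapper checks for.
    def packageResult(weightsBytes: Array[Byte], intercept: Double): java.util.List[Object] = {
      val ret = new java.util.LinkedList[Object]()
      ret.add(weightsBytes)                    // element 0: arrives in Python as a bytearray
      ret.add(intercept.asInstanceOf[Object])  // element 1: arrives in Python as a float
      ret
    }

This mirrors what trainRegressionModel in PythonMLLibAPI.scala appears to return; its signature, shown further down in this series, is a java.util.LinkedList[java.lang.Object].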
@@ -194,14 +194,14 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): dataBytes = _get_unmangled_double_vector_rdd(data) ans = train_func(dataBytes, _serialize_double_vector(initial_weights)) if len(ans) != 2: - raise RuntimeError("JVM call result had unexpected length"); + raise RuntimeError("JVM call result had unexpected length") elif type(ans[0]) != bytearray: raise RuntimeError("JVM call result had first element of type " - + type(ans[0]) + " which is not bytearray"); + + type(ans[0]) + " which is not bytearray") elif type(ans[1]) != float: raise RuntimeError("JVM call result had second element of type " - + type(ans[0]) + " which is not float"); - return klass(_deserialize_double_vector(ans[0]), ans[1]); + + type(ans[0]) + " which is not float") + return klass(_deserialize_double_vector(ans[0]), ans[1]) class LinearRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit. @@ -324,11 +324,11 @@ class KMeansModel(object): ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd, k, maxIterations, runs, initialization_mode) if len(ans) != 1: - raise RuntimeError("JVM call result had unexpected length"); + raise RuntimeError("JVM call result had unexpected length") elif type(ans[0]) != bytearray: raise RuntimeError("JVM call result had first element of type " - + type(ans[0]) + " which is not bytearray"); - return KMeansModel(_deserialize_double_matrix(ans[0])); + + type(ans[0]) + " which is not bytearray") + return KMeansModel(_deserialize_double_matrix(ans[0])) def _test(): import doctest -- cgit v1.2.3 From b454fdc2ebc495e4d13162f4bea8cf3e33909463 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 20 Dec 2013 02:10:21 -0500 Subject: Javadocs; also, declare some things private. --- .../apache/spark/mllib/api/PythonMLLibAPI.scala | 31 ++++++++++++++++++---- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index bad1f66424..6472bf6367 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -24,8 +24,11 @@ import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.DoubleBuffer +/** + * The Java stubs necessary for the Python mllib bindings. 
+ */ class PythonMLLibAPI extends Serializable { - def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { + private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { val packetLength = bytes.length if (packetLength < 16) { throw new IllegalArgumentException("Byte array too short.") @@ -46,7 +49,7 @@ class PythonMLLibAPI extends Serializable { return ans } - def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { + private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { val len = doubles.length val bytes = new Array[Byte](16 + 8 * len) val bb = ByteBuffer.wrap(bytes) @@ -58,7 +61,7 @@ class PythonMLLibAPI extends Serializable { return bytes } - def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { + private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { val packetLength = bytes.length if (packetLength < 24) { throw new IllegalArgumentException("Byte array too short.") @@ -84,7 +87,7 @@ class PythonMLLibAPI extends Serializable { return ans } - def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { + private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { val rows = doubles.length var cols = 0 if (rows > 0) { @@ -104,7 +107,7 @@ class PythonMLLibAPI extends Serializable { return bytes } - def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, + private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { @@ -119,6 +122,9 @@ class PythonMLLibAPI extends Serializable { return ret } + /** + * Java stub for Python mllib LinearRegressionModel.train() + */ def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { @@ -128,6 +134,9 @@ class PythonMLLibAPI extends Serializable { dataBytesJRDD, initialWeightsBA) } + /** + * Java stub for Python mllib LassoModel.train() + */ def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { @@ -137,6 +146,9 @@ class PythonMLLibAPI extends Serializable { dataBytesJRDD, initialWeightsBA) } + /** + * Java stub for Python mllib RidgeRegressionModel.train() + */ def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { @@ -146,6 +158,9 @@ class PythonMLLibAPI extends Serializable { dataBytesJRDD, initialWeightsBA) } + /** + * Java stub for Python mllib SVMModel.train() + */ def trainSVMModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { @@ -155,6 +170,9 @@ class PythonMLLibAPI extends Serializable { dataBytesJRDD, initialWeightsBA) } + /** + * Java stub for Python mllib LogisticRegressionModel.train() + */ def trainLogisticRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): 
java.util.List[java.lang.Object] = { @@ -164,6 +182,9 @@ class PythonMLLibAPI extends Serializable { dataBytesJRDD, initialWeightsBA) } + /** + * Java stub for Python mllib KMeansModel.train() + */ def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, maxIterations: Int, runs: Int, initializationMode: String): java.util.List[java.lang.Object] = { -- cgit v1.2.3 From 30186aa2648f90d0ad4e312d28e99c9378ea317a Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Fri, 20 Dec 2013 14:58:04 -0800 Subject: Renamed ClusterScheduler to TaskSchedulerImpl --- .../main/scala/org/apache/spark/SparkContext.scala | 20 +- .../apache/spark/scheduler/ClusterScheduler.scala | 473 --------------------- .../apache/spark/scheduler/TaskResultGetter.scala | 2 +- .../apache/spark/scheduler/TaskSchedulerImpl.scala | 473 +++++++++++++++++++++ .../apache/spark/scheduler/TaskSetManager.scala | 2 +- .../cluster/CoarseGrainedSchedulerBackend.scala | 4 +- .../scheduler/cluster/SimrSchedulerBackend.scala | 4 +- .../cluster/SparkDeploySchedulerBackend.scala | 4 +- .../mesos/CoarseMesosSchedulerBackend.scala | 4 +- .../cluster/mesos/MesosSchedulerBackend.scala | 4 +- .../spark/scheduler/local/LocalBackend.scala | 6 +- .../spark/SparkContextSchedulerCreationSuite.scala | 6 +- .../spark/scheduler/ClusterSchedulerSuite.scala | 10 +- .../spark/scheduler/TaskResultGetterSuite.scala | 6 +- .../spark/scheduler/TaskSetManagerSuite.scala | 2 +- 15 files changed, 510 insertions(+), 510 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 663b473e5d..ad3337d94c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1044,25 +1044,25 @@ object SparkContext { master match { case "local" => - val scheduler = new ClusterScheduler(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) + val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, 1) scheduler.initialize(backend) scheduler case LOCAL_N_REGEX(threads) => - val scheduler = new ClusterScheduler(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) + val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, threads.toInt) scheduler.initialize(backend) scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - val scheduler = new ClusterScheduler(sc, maxFailures.toInt, isLocal = true) + val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) val backend = new LocalBackend(scheduler, threads.toInt) scheduler.initialize(backend) scheduler case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) scheduler.initialize(backend) @@ -1077,7 +1077,7 @@ object SparkContext { memoryPerSlaveInt, SparkContext.executorMemoryRequested)) } - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val masterUrls = localCluster.start() @@ -1092,7 +1092,7 @@ object SparkContext { val 
scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(sc).asInstanceOf[ClusterScheduler] + cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { // TODO: Enumerate the exact reasons why it can fail // But irrespective of it, it means we cannot proceed ! @@ -1108,7 +1108,7 @@ object SparkContext { val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(sc).asInstanceOf[ClusterScheduler] + cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { case th: Throwable => { @@ -1118,7 +1118,7 @@ object SparkContext { val backend = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") - val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { case th: Throwable => { @@ -1131,7 +1131,7 @@ object SparkContext { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { @@ -1143,7 +1143,7 @@ object SparkContext { scheduler case SIMR_REGEX(simrUrl) => - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) scheduler.initialize(backend) scheduler diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala deleted file mode 100644 index 1ad735bc04..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala +++ /dev/null @@ -1,473 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong -import java.util.{TimerTask, Timer} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.concurrent.duration._ - -import org.apache.spark._ -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode - -/** - * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. 
- * It can also work with a local setup by using a LocalBackend and setting isLocal to true. - * It handles common logic, like determining a scheduling order across jobs, waking up to launch - * speculative tasks, etc. - * - * Clients should first call initialize() and start(), then submit task sets through the - * runTasks method. - * - * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple - * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends sycnchronize on themselves when they want to send events here, and then - * acquire a lock on us, so we need to make sure that we don't try to lock the backend while - * we are holding a lock on ourselves. - */ -private[spark] class ClusterScheduler( - val sc: SparkContext, - val maxTaskFailures : Int = System.getProperty("spark.task.maxFailures", "4").toInt, - isLocal: Boolean = false) extends TaskScheduler with Logging { - - // How often to check for speculative tasks - val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong - - // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong - - // TaskSetManagers are not thread safe, so any access to one should be synchronized - // on this class. - val activeTaskSets = new HashMap[String, TaskSetManager] - - val taskIdToTaskSetId = new HashMap[Long, String] - val taskIdToExecutorId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - @volatile private var hasReceivedTask = false - @volatile private var hasLaunchedTask = false - private val starvationTimer = new Timer(true) - - // Incrementing task IDs - val nextTaskId = new AtomicLong(0) - - // Which executor IDs we have executors on - val activeExecutorIds = new HashSet[String] - - // The set of executors we have on each host; this is used to compute hostsAlive, which - // in turn is used to decide when we can attain data locality on a given host - private val executorsByHost = new HashMap[String, HashSet[String]] - - private val executorIdToHost = new HashMap[String, String] - - // Listener object to pass upcalls into - var dagScheduler: DAGScheduler = null - - var backend: SchedulerBackend = null - - val mapOutputTracker = SparkEnv.get.mapOutputTracker - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - // default scheduler is FIFO - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.scheduler.mode", "FIFO")) - - // This is a var so that we can reset it for testing purposes. 
- private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) - - override def setDAGScheduler(dagScheduler: DAGScheduler) { - this.dagScheduler = dagScheduler - } - - def initialize(backend: SchedulerBackend) { - this.backend = backend - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - } - - def newTaskId(): Long = nextTaskId.getAndIncrement() - - override def start() { - backend.start() - - if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) { - logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, - SPECULATION_INTERVAL milliseconds) { - checkSpeculatableTasks() - } - } - } - - override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") - this.synchronized { - val manager = new TaskSetManager(this, taskSet, maxTaskFailures) - activeTaskSets(taskSet.id) = manager - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - - if (!isLocal && !hasReceivedTask) { - starvationTimer.scheduleAtFixedRate(new TimerTask() { - override def run() { - if (!hasLaunchedTask) { - logWarning("Initial job has not accepted any resources; " + - "check your cluster UI to ensure that workers are registered " + - "and have sufficient memory") - } else { - this.cancel() - } - } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) - } - hasReceivedTask = true - } - backend.reviveOffers() - } - - override def cancelTasks(stageId: Int): Unit = synchronized { - logInfo("Cancelling stage " + stageId) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - val taskIds = taskSetTaskIds(tsm.taskSet.id) - if (taskIds.size > 0) { - taskIds.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId) - } - } - logInfo("Stage %d was cancelled".format(stageId)) - tsm.removeAllRunningTasks() - taskSetFinished(tsm) - } - } - - def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - // Check to see if the given task set has been removed. This is possible in the case of - // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has - // more than one running tasks). - if (activeTaskSets.contains(manager.taskSet.id)) { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } - } - - /** - * Called by cluster manager to offer resources on slaves. We respond by asking our active task - * sets for tasks in order of priority. 
We fill each node with tasks in a round-robin manner so - * that tasks are balanced across the cluster. - */ - def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - - // Mark each slave as alive and remember its hostname - for (o <- offers) { - executorIdToHost(o.executorId) = o.host - if (!executorsByHost.contains(o.host)) { - executorsByHost(o.host) = new HashSet[String]() - executorGained(o.executorId, o.host) - } - } - - // Build a list of tasks to assign to each worker - val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue() - for (taskSet <- sortedTaskSets) { - logDebug("parentName: %s, name: %s, runningTasks: %s".format( - taskSet.parent.name, taskSet.name, taskSet.runningTasks)) - } - - // Take each TaskSet in our scheduling order, and then offer it each node in increasing order - // of locality levels so that it gets a chance to launch local tasks on all of them. - var launchedTask = false - for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { - do { - launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).host - for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskSetTaskIds(taskSet.taskSet.id) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - launchedTask = true - } - } - } while (launchedTask) - } - - if (tasks.size > 0) { - hasLaunchedTask = true - } - return tasks - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - var failedExecutor: Option[String] = None - synchronized { - try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) - failedExecutor = Some(execId) - } - } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => - if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } - taskIdToExecutorId.remove(tid) - } - activeTaskSets.get(taskSetId).foreach { taskSet => - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) - } - } - case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") - } - } catch { - case e: Exception => logError("Exception in statusUpdate", e) - } - } - // Update the DAGScheduler without holding a lock on this, since that can deadlock - if (failedExecutor != None) { - dagScheduler.executorLost(failedExecutor.get) - backend.reviveOffers() - } - } - - def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { - taskSetManager.handleTaskGettingResult(tid) - } - - def handleSuccessfulTask( - taskSetManager: TaskSetManager, - tid: Long, - taskResult: DirectTaskResult[_]) = synchronized { - taskSetManager.handleSuccessfulTask(tid, taskResult) - } - - def 
handleFailedTask( - taskSetManager: TaskSetManager, - tid: Long, - taskState: TaskState, - reason: Option[TaskEndReason]) = synchronized { - taskSetManager.handleFailedTask(tid, taskState, reason) - if (taskState != TaskState.KILLED) { - // Need to revive offers again now that the task set manager state has been updated to - // reflect failed tasks that need to be re-run. - backend.reviveOffers() - } - } - - def error(message: String) { - synchronized { - if (activeTaskSets.size > 0) { - // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { - try { - manager.error(message) - } catch { - case e: Exception => logError("Exception in error callback", e) - } - } - } else { - // No task sets are active but we still got an error. Just exit since this - // must mean the error is during registration. - // It might be good to do something smarter here in the future. - logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) - } - } - } - - override def stop() { - if (backend != null) { - backend.stop() - } - if (taskResultGetter != null) { - taskResultGetter.stop() - } - - // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. - // TODO: Do something better ! - Thread.sleep(5000L) - } - - override def defaultParallelism() = backend.defaultParallelism() - - // Check for speculatable tasks in all our active jobs. - def checkSpeculatableTasks() { - var shouldRevive = false - synchronized { - shouldRevive = rootPool.checkSpeculatableTasks() - } - if (shouldRevive) { - backend.reviveOffers() - } - } - - // Check for pending tasks in all our active jobs. - def hasPendingTasks: Boolean = { - synchronized { - rootPool.hasPendingTasks() - } - } - - def executorLost(executorId: String, reason: ExecutorLossReason) { - var failedExecutor: Option[String] = None - - synchronized { - if (activeExecutorIds.contains(executorId)) { - val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) - failedExecutor = Some(executorId) - } else { - // We may get multiple executorLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. 
- logError("Lost an executor " + executorId + " (already removed): " + reason) - } - } - // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock - if (failedExecutor != None) { - dagScheduler.executorLost(failedExecutor.get) - backend.reviveOffers() - } - } - - /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { - activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) - execs -= executorId - if (execs.isEmpty) { - executorsByHost -= host - } - executorIdToHost -= executorId - rootPool.executorLost(executorId, host) - } - - def executorGained(execId: String, host: String) { - dagScheduler.executorGained(execId, host) - } - - def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { - executorsByHost.get(host).map(_.toSet) - } - - def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { - executorsByHost.contains(host) - } - - def isExecutorAlive(execId: String): Boolean = synchronized { - activeExecutorIds.contains(execId) - } - - // By default, rack is unknown - def getRackForHost(value: String): Option[String] = None -} - - -private[spark] object ClusterScheduler { - /** - * Used to balance containers across hosts. - * - * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource - * offers are ordered such that we'll allocate one container on each host before allocating a - * second container on any host, and so on, in order to reduce the damage if a host fails. - * - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] - */ - def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { - val _keyList = new ArrayBuffer[K](map.size) - _keyList ++= map.keys - - // order keyList based on population of value in map - val keyList = _keyList.sortWith( - (left, right) => map(left).size > map(right).size - ) - - val retval = new ArrayBuffer[T](keyList.size * 2) - var index = 0 - var found = true - - while (found) { - found = false - for (key <- keyList) { - val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) - assert(containerList != null) - // Get the index'th entry for this host - if present - if (index < containerList.size){ - retval += containerList.apply(index) - found = true - } - } - index += 1 - } - - retval.toList - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 7b5543e222..89102720fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.Utils /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. 
*/ -private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) +private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends Logging { private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala new file mode 100644 index 0000000000..7409168f7b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicLong +import java.util.{TimerTask, Timer} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.concurrent.duration._ + +import org.apache.spark._ +import org.apache.spark.TaskState.TaskState +import org.apache.spark.scheduler.SchedulingMode.SchedulingMode + +/** + * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. + * It can also work with a local setup by using a LocalBackend and setting isLocal to true. + * It handles common logic, like determining a scheduling order across jobs, waking up to launch + * speculative tasks, etc. + * + * Clients should first call initialize() and start(), then submit task sets through the + * runTasks method. + * + * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple + * threads, so it needs locks in public API methods to maintain its state. In addition, some + * SchedulerBackends sycnchronize on themselves when they want to send events here, and then + * acquire a lock on us, so we need to make sure that we don't try to lock the backend while + * we are holding a lock on ourselves. + */ +private[spark] class TaskSchedulerImpl( + val sc: SparkContext, + val maxTaskFailures : Int = System.getProperty("spark.task.maxFailures", "4").toInt, + isLocal: Boolean = false) extends TaskScheduler with Logging { + + // How often to check for speculative tasks + val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong + + // Threshold above which we warn user initial TaskSet may be starved + val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong + + // TaskSetManagers are not thread safe, so any access to one should be synchronized + // on this class. 
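+ // Lock-ordering note (see statusUpdate() and executorLost() further down): state changes are
+ // recorded while synchronized on this scheduler, but calls out to the DAGScheduler and the
+ // backend (e.g. reviveOffers) are made only after the lock is released, to avoid deadlock.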
+ val activeTaskSets = new HashMap[String, TaskSetManager] + + val taskIdToTaskSetId = new HashMap[Long, String] + val taskIdToExecutorId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + @volatile private var hasReceivedTask = false + @volatile private var hasLaunchedTask = false + private val starvationTimer = new Timer(true) + + // Incrementing task IDs + val nextTaskId = new AtomicLong(0) + + // Which executor IDs we have executors on + val activeExecutorIds = new HashSet[String] + + // The set of executors we have on each host; this is used to compute hostsAlive, which + // in turn is used to decide when we can attain data locality on a given host + private val executorsByHost = new HashMap[String, HashSet[String]] + + private val executorIdToHost = new HashMap[String, String] + + // Listener object to pass upcalls into + var dagScheduler: DAGScheduler = null + + var backend: SchedulerBackend = null + + val mapOutputTracker = SparkEnv.get.mapOutputTracker + + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + // default scheduler is FIFO + val schedulingMode: SchedulingMode = SchedulingMode.withName( + System.getProperty("spark.scheduler.mode", "FIFO")) + + // This is a var so that we can reset it for testing purposes. + private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) + + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler + } + + def initialize(backend: SchedulerBackend) { + this.backend = backend + // temporarily set rootPool name to empty + rootPool = new Pool("", schedulingMode, 0, 0) + schedulableBuilder = { + schedulingMode match { + case SchedulingMode.FIFO => + new FIFOSchedulableBuilder(rootPool) + case SchedulingMode.FAIR => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() + } + + def newTaskId(): Long = nextTaskId.getAndIncrement() + + override def start() { + backend.start() + + if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) { + logInfo("Starting speculative execution thread") + import sc.env.actorSystem.dispatcher + sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, + SPECULATION_INTERVAL milliseconds) { + checkSpeculatableTasks() + } + } + } + + override def submitTasks(taskSet: TaskSet) { + val tasks = taskSet.tasks + logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") + this.synchronized { + val manager = new TaskSetManager(this, taskSet, maxTaskFailures) + activeTaskSets(taskSet.id) = manager + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + + if (!isLocal && !hasReceivedTask) { + starvationTimer.scheduleAtFixedRate(new TimerTask() { + override def run() { + if (!hasLaunchedTask) { + logWarning("Initial job has not accepted any resources; " + + "check your cluster UI to ensure that workers are registered " + + "and have sufficient memory") + } else { + this.cancel() + } + } + }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + } + hasReceivedTask = true + } + backend.reviveOffers() + } + + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. 
+ // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } + logInfo("Stage %d was cancelled".format(stageId)) + tsm.removeAllRunningTasks() + taskSetFinished(tsm) + } + } + + def taskSetFinished(manager: TaskSetManager): Unit = synchronized { + // Check to see if the given task set has been removed. This is possible in the case of + // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has + // more than one running tasks). + if (activeTaskSets.contains(manager.taskSet.id)) { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds.remove(manager.taskSet.id) + } + } + + /** + * Called by cluster manager to offer resources on slaves. We respond by asking our active task + * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so + * that tasks are balanced across the cluster. + */ + def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { + SparkEnv.set(sc.env) + + // Mark each slave as alive and remember its hostname + for (o <- offers) { + executorIdToHost(o.executorId) = o.host + if (!executorsByHost.contains(o.host)) { + executorsByHost(o.host) = new HashSet[String]() + executorGained(o.executorId, o.host) + } + } + + // Build a list of tasks to assign to each worker + val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = offers.map(o => o.cores).toArray + val sortedTaskSets = rootPool.getSortedTaskSetQueue() + for (taskSet <- sortedTaskSets) { + logDebug("parentName: %s, name: %s, runningTasks: %s".format( + taskSet.parent.name, taskSet.name, taskSet.runningTasks)) + } + + // Take each TaskSet in our scheduling order, and then offer it each node in increasing order + // of locality levels so that it gets a chance to launch local tasks on all of them. 
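+ // The maxLocality loop below relaxes locality one level per pass; assuming this version's
+ // TaskLocality ordering (PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY), each task set is first
+ // offered slots on executors that already hold its data, then same-host slots, then same-rack
+ // slots, and only then arbitrary slots.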
+ var launchedTask = false + for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { + do { + launchedTask = false + for (i <- 0 until offers.size) { + val execId = offers(i).executorId + val host = offers(i).host + for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskSetTaskIds(taskSet.taskSet.id) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + launchedTask = true + } + } + } while (launchedTask) + } + + if (tasks.size > 0) { + hasLaunchedTask = true + } + return tasks + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + var failedExecutor: Option[String] = None + synchronized { + try { + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) + } + } + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => + if (TaskState.isFinished(state)) { + taskIdToTaskSetId.remove(tid) + if (taskSetTaskIds.contains(taskSetId)) { + taskSetTaskIds(taskSetId) -= tid + } + taskIdToExecutorId.remove(tid) + } + activeTaskSets.get(taskSetId).foreach { taskSet => + if (state == TaskState.FINISHED) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + } + } + case None => + logInfo("Ignoring update from TID " + tid + " because its task set is gone") + } + } catch { + case e: Exception => logError("Exception in statusUpdate", e) + } + } + // Update the DAGScheduler without holding a lock on this, since that can deadlock + if (failedExecutor != None) { + dagScheduler.executorLost(failedExecutor.get) + backend.reviveOffers() + } + } + + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { + taskSetManager.handleTaskGettingResult(tid) + } + + def handleSuccessfulTask( + taskSetManager: TaskSetManager, + tid: Long, + taskResult: DirectTaskResult[_]) = synchronized { + taskSetManager.handleSuccessfulTask(tid, taskResult) + } + + def handleFailedTask( + taskSetManager: TaskSetManager, + tid: Long, + taskState: TaskState, + reason: Option[TaskEndReason]) = synchronized { + taskSetManager.handleFailedTask(tid, taskState, reason) + if (taskState != TaskState.KILLED) { + // Need to revive offers again now that the task set manager state has been updated to + // reflect failed tasks that need to be re-run. + backend.reviveOffers() + } + } + + def error(message: String) { + synchronized { + if (activeTaskSets.size > 0) { + // Have each task set throw a SparkException with the error + for ((taskSetId, manager) <- activeTaskSets) { + try { + manager.error(message) + } catch { + case e: Exception => logError("Exception in error callback", e) + } + } + } else { + // No task sets are active but we still got an error. Just exit since this + // must mean the error is during registration. + // It might be good to do something smarter here in the future. 
+ logError("Exiting due to error from cluster scheduler: " + message) + System.exit(1) + } + } + } + + override def stop() { + if (backend != null) { + backend.stop() + } + if (taskResultGetter != null) { + taskResultGetter.stop() + } + + // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. + // TODO: Do something better ! + Thread.sleep(5000L) + } + + override def defaultParallelism() = backend.defaultParallelism() + + // Check for speculatable tasks in all our active jobs. + def checkSpeculatableTasks() { + var shouldRevive = false + synchronized { + shouldRevive = rootPool.checkSpeculatableTasks() + } + if (shouldRevive) { + backend.reviveOffers() + } + } + + // Check for pending tasks in all our active jobs. + def hasPendingTasks: Boolean = { + synchronized { + rootPool.hasPendingTasks() + } + } + + def executorLost(executorId: String, reason: ExecutorLossReason) { + var failedExecutor: Option[String] = None + + synchronized { + if (activeExecutorIds.contains(executorId)) { + val hostPort = executorIdToHost(executorId) + logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + removeExecutor(executorId) + failedExecutor = Some(executorId) + } else { + // We may get multiple executorLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor " + executorId + " (already removed): " + reason) + } + } + // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock + if (failedExecutor != None) { + dagScheduler.executorLost(failedExecutor.get) + backend.reviveOffers() + } + } + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + rootPool.executorLost(executorId, host) + } + + def executorGained(execId: String, host: String) { + dagScheduler.executorGained(execId, host) + } + + def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { + executorsByHost.get(host).map(_.toSet) + } + + def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { + executorsByHost.contains(host) + } + + def isExecutorAlive(execId: String): Boolean = synchronized { + activeExecutorIds.contains(execId) + } + + // By default, rack is unknown + def getRackForHost(value: String): Option[String] = None +} + + +private[spark] object TaskSchedulerImpl { + /** + * Used to balance containers across hosts. + * + * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of + * resource offers representing the order in which the offers should be used. The resource + * offers are ordered such that we'll allocate one container on each host before allocating a + * second container on any host, and so on, in order to reduce the damage if a host fails. 
+ * + * For example, given , , , returns + * [o1, o5, o4, 02, o6, o3] + */ + def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { + val _keyList = new ArrayBuffer[K](map.size) + _keyList ++= map.keys + + // order keyList based on population of value in map + val keyList = _keyList.sortWith( + (left, right) => map(left).size > map(right).size + ) + + val retval = new ArrayBuffer[T](keyList.size * 2) + var index = 0 + var found = true + + while (found) { + found = false + for (key <- keyList) { + val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) + assert(containerList != null) + // Get the index'th entry for this host - if present + if (index < containerList.size){ + retval += containerList.apply(index) + found = true + } + } + index += 1 + } + + retval.toList + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 0fe413a7c4..0ac982909c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -48,7 +48,7 @@ import java.io.NotSerializableException * task set will be aborted */ private[spark] class TaskSetManager( - sched: ClusterScheduler, + sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, clock: Clock = SystemClock) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5797783793..5c534a6f43 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -29,7 +29,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.{Logging, SparkException, TaskState} -import org.apache.spark.scheduler.{ClusterScheduler, SchedulerBackend, SlaveLost, TaskDescription, +import org.apache.spark.scheduler.{TaskSchedulerImpl, SchedulerBackend, SlaveLost, TaskDescription, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -43,7 +43,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} * (spark.deploy.*). 
*/ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 2fbd725d75..ec3e68e970 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -21,10 +21,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.scheduler.ClusterScheduler +import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class SimrSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 1d38f0d956..404ce7a452 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -22,11 +22,11 @@ import scala.collection.mutable.HashMap import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.client.{Client, ClientListener} import org.apache.spark.deploy.{Command, ApplicationDescription} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, ClusterScheduler} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String], appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 5481828111..39573fc8c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,7 +30,7 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{SparkException, Logging, SparkContext, TaskState} -import org.apache.spark.scheduler.ClusterScheduler +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend /** @@ -44,7 +44,7 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend * remove this. 
*/ private[spark] class CoarseMesosSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, master: String, appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 773b980c53..6aa788c460 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -31,7 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas import org.apache.spark.{Logging, SparkException, SparkContext, TaskState} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, - TaskDescription, ClusterScheduler, WorkerOffer} + TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.util.Utils /** @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils * from multiple apps can run on different cores) and in time (a core can switch ownership). */ private[spark] class MesosSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, master: String, appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 6b5f1a5dc2..69c1c04843 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -24,7 +24,7 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.scheduler.{SchedulerBackend, ClusterScheduler, WorkerOffer} +import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} private case class ReviveOffers() @@ -38,7 +38,7 @@ private case class KillTask(taskId: Long) * and the ClusterScheduler. */ private[spark] class LocalActor( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) extends Actor with Logging { @@ -78,7 +78,7 @@ private[spark] class LocalActor( * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. 
*/ -private[spark] class LocalBackend(scheduler: ClusterScheduler, val totalCores: Int) +private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) extends SchedulerBackend with ExecutorBackend { var localActor: ActorRef = null diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index d4a7a11515..9deed568ac 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import org.scalatest.{FunSuite, PrivateMethodTester} -import org.apache.spark.scheduler.{ClusterScheduler, TaskScheduler} +import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskScheduler} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend @@ -27,13 +27,13 @@ import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { - def createTaskScheduler(master: String): ClusterScheduler = { + def createTaskScheduler(master: String): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. sc = new SparkContext("local", "test") val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") - sched.asInstanceOf[ClusterScheduler] + sched.asInstanceOf[TaskSchedulerImpl] } test("bad-master") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala index 35a06c4875..702edb862f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -29,7 +29,7 @@ class FakeTaskSetManager( initPriority: Int, initStageId: Int, initNumTasks: Int, - clusterScheduler: ClusterScheduler, + clusterScheduler: TaskSchedulerImpl, taskSet: TaskSet) extends TaskSetManager(clusterScheduler, taskSet, 0) { @@ -104,7 +104,7 @@ class FakeTaskSetManager( class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = { new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) } @@ -131,7 +131,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("FIFO Scheduler Test") { sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -158,7 +158,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("Fair Scheduler Test") { sc = new SparkContext("local", 
"ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -215,7 +215,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("Nested Pool Test") { sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 9784920653..2265619570 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.TaskResultBlockId * Used to test the case where a BlockManager evicts the task result (or dies) before the * TaskResult is retrieved. */ -class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) +class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends TaskResultGetter(sparkEnv, scheduler) { var removedResult = false @@ -92,8 +92,8 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA sc = new SparkContext("local[1,1]", "test") // If this test hangs, it's probably because no resource offers were made after the task // failed. - val scheduler: ClusterScheduler = sc.taskScheduler match { - case clusterScheduler: ClusterScheduler => + val scheduler: TaskSchedulerImpl = sc.taskScheduler match { + case clusterScheduler: TaskSchedulerImpl => clusterScheduler case _ => assert(false, "Expect local cluster to use ClusterScheduler") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index b34b6f32f2..771a64ff6c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -58,7 +58,7 @@ class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler * to work, and these are required for locality in TaskSetManager. */ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) - extends ClusterScheduler(sc) + extends TaskSchedulerImpl(sc) { val startedTasks = new ArrayBuffer[Long] val endedTasks = new mutable.HashMap[Long, TaskEndReason] -- cgit v1.2.3 From 3ddbdbfbc71486cd5076d875f82796a880d2dccb Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 20 Dec 2013 19:51:37 -0800 Subject: Minor updated based on comments on PR 277. 
--- .../scala/org/apache/spark/streaming/scheduler/JobScheduler.scala | 4 +++- .../src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 33c5322358..9511ccfbed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -24,7 +24,8 @@ import scala.collection.mutable.HashSet import org.apache.spark.streaming._ /** - * This class drives the generation of Spark jobs from the DStreams. + * This class schedules jobs to be run on Spark. It uses the JobGenerator to generate + * the jobs and runs them using a thread pool. Number of threads */ private[streaming] class JobScheduler(val ssc: StreamingContext) extends Logging { @@ -91,6 +92,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } } + private[streaming] class JobHandler(job: Job) extends Runnable { def run() { beforeJobStart(job) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 05233d095b..cf7431a8a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -20,6 +20,9 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.HashSet import org.apache.spark.streaming.Time +/** Class representing a set of Jobs + * belong to the same batch. + */ private[streaming] case class JobSet(time: Time, jobs: Seq[Job]) { -- cgit v1.2.3 From 076fc1622190d342e20592c00ca19f8c0a56997f Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Sat, 21 Dec 2013 14:54:01 -0500 Subject: Python stubs for ALSModel. 
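A minimal usage sketch of the binding added here, based on the ALSModel API and doctest in the patch below; the ratings data and application name are made up for illustration:

    from pyspark import SparkContext
    from pyspark.mllib import ALSModel

    sc = SparkContext("local", "als-example")
    # Ratings are plain (user, product, rating) tuples.
    ratings = sc.parallelize([(1, 1, 1.0), (1, 2, 2.0), (2, 1, 2.0)])
    model = ALSModel.train(sc, ratings, rank=1, iterations=5)
    print model.predict(2, 2)   # predicted rating for user 2, product 2
    sc.stop()

The Python object is only a thin wrapper around a JVM-side model handle; predict() is forwarded over Py4J.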
--- python/pyspark/__init__.py | 5 ++-- python/pyspark/mllib.py | 59 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 8b5bb79a18..3d73d95909 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -43,9 +43,10 @@ from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel from pyspark.mllib import LinearRegressionModel, LassoModel, \ - RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel + RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel, \ + ALSModel __all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", "LinearRegressionModel", "LassoModel", "RidgeRegressionModel", - "LogisticRegressionModel", "SVMModel", "KMeansModel"]; + "LogisticRegressionModel", "SVMModel", "KMeansModel", "ALSModel"]; diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 8848284a5e..22187eb4dd 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -164,14 +164,17 @@ class LinearRegressionModelBase(LinearModel): _linear_predictor_typecheck(x, self._coeff) return dot(self._coeff, x) + self._intercept -# Map a pickled Python RDD of numpy double vectors to a Java RDD of -# _serialized_double_vectors -def _get_unmangled_double_vector_rdd(data): - dataBytes = data.map(_serialize_double_vector) +def _get_unmangled_rdd(data, serializer): + dataBytes = data.map(serializer) dataBytes._bypass_serializer = True dataBytes.cache() return dataBytes +# Map a pickled Python RDD of numpy double vectors to a Java RDD of +# _serialized_double_vectors +def _get_unmangled_double_vector_rdd(data): + return _get_unmangled_rdd(data, _serialize_double_vector) + # If we weren't given initial weights, take a zero vector of the appropriate # length. def _get_initial_weights(initial_weights, data): @@ -317,7 +320,7 @@ class KMeansModel(object): return best @classmethod - def train(cls, sc, data, k, maxIterations = 100, runs = 1, + def train(cls, sc, data, k, maxIterations=100, runs=1, initialization_mode="k-means||"): """Train a k-means clustering model.""" dataBytes = _get_unmangled_double_vector_rdd(data) @@ -330,12 +333,56 @@ class KMeansModel(object): + type(ans[0]) + " which is not bytearray") return KMeansModel(_deserialize_double_matrix(ans[0])) +def _serialize_rating(r): + ba = bytearray(16) + intpart = ndarray(shape=[2], buffer=ba, dtype=int32) + doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8) + intpart[0], intpart[1], doublepart[0] = r + return ba + +class ALSModel(object): + """A matrix factorisation model trained by regularized alternating + least-squares. 
+ + >>> r1 = (1, 1, 1.0) + >>> r2 = (1, 2, 2.0) + >>> r3 = (2, 1, 2.0) + >>> ratings = sc.parallelize([r1, r2, r3]) + >>> model = ALSModel.trainImplicit(sc, ratings, 1) + >>> model.predict(2,2) is not None + True + """ + + def __init__(self, sc, java_model): + self._context = sc + self._java_model = java_model + + #def __del__(self): + #self._gateway.detach(self._java_model) + + def predict(self, user, product): + return self._java_model.predict(user, product) + + @classmethod + def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): + ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd, + rank, iterations, lambda_, blocks) + return ALSModel(sc, mod) + + @classmethod + def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): + ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd, + rank, iterations, lambda_, blocks, alpha) + return ALSModel(sc, mod) + def _test(): import doctest globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, - optionflags=doctest.ELLIPSIS) + optionflags=doctest.ELLIPSIS) globs['sc'].stop() print failure_count,"failures among",test_count,"tests" if failure_count: -- cgit v1.2.3 From 20f85eca3d924aecd0fcf61cd516d9ac8e369dc1 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Sat, 21 Dec 2013 14:54:13 -0500 Subject: Java stubs for ALSModel. --- .../apache/spark/mllib/api/PythonMLLibAPI.scala | 34 ++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index 6472bf6367..4620cab175 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -19,6 +19,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.recommendation._ import org.apache.spark.rdd.RDD import java.nio.ByteBuffer import java.nio.ByteOrder @@ -194,4 +195,37 @@ class PythonMLLibAPI extends Serializable { ret.add(serializeDoubleMatrix(model.clusterCenters)) return ret } + + private def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + return new Rating(user, product, rating) + } + + /** + * Java stub for Python mllib ALSModel.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.train(ratings, rank, iterations, lambda, blocks) + } + + /** + * Java stub for Python mllib ALSModel.trainImplicit(). This stub returns a + * handle to the Java object instead of the content of the Java object. 
+ * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) + } } -- cgit v1.2.3 From b8ae096a40eb0f83aac889deb061a9484effd9aa Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Sat, 21 Dec 2013 23:28:48 -0800 Subject: Fix build error in test --- .../src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 771a64ff6c..3dcb01ae5e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -283,7 +283,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted // after the last failure. - (1 to manager.MAX_TASK_FAILURES).foreach { index => + (1 to manager.maxTaskFailures).foreach { index => val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) assert(offerResult != None, "Expect resource offer on iteration %s to return a task".format(index)) -- cgit v1.2.3 From c979eecdf6a11462595aba9d5b8fc942682cf85d Mon Sep 17 00:00:00 2001 From: "wangda.tan" Date: Sun, 22 Dec 2013 21:43:15 +0800 Subject: added changes according to comments from rxin --- .../org/apache/spark/ui/exec/ExecutorsUI.scala | 24 +++++++-------------- .../org/apache/spark/ui/jobs/ExecutorSummary.scala | 5 +++-- .../org/apache/spark/ui/jobs/ExecutorTable.scala | 4 ++-- .../scala/org/apache/spark/ui/jobs/IndexPage.scala | 4 ---- .../apache/spark/ui/jobs/JobProgressListener.scala | 25 +++++++--------------- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 4 ++-- .../org/apache/spark/ui/jobs/StageTable.scala | 2 +- 7 files changed, 24 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index f62ae37466..a31a7e1d58 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -56,7 +56,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used", - "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Duration", "Shuffle Read", + "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Task Time", "Shuffle Read", "Shuffle Write") def execRow(kv: Seq[String]) = { @@ -169,21 +169,13 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { // update shuffle read/write if (null != taskEnd.taskMetrics) { - val shuffleRead = taskEnd.taskMetrics.shuffleReadMetrics - shuffleRead match { - case Some(s) => - val newShuffleRead = executorToShuffleRead.getOrElse(eid, 0L) + s.remoteBytesRead - executorToShuffleRead.put(eid, newShuffleRead) - case _ => {} - } - val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics - shuffleWrite 
match { - case Some(s) => { - val newShuffleWrite = executorToShuffleWrite.getOrElse(eid, 0L) + s.shuffleBytesWritten - executorToShuffleWrite.put(eid, newShuffleWrite) - } - case _ => {} - } + taskEnd.taskMetrics.shuffleReadMetrics.foreach(shuffleRead => + executorToShuffleRead.put(eid, executorToShuffleRead.getOrElse(eid, 0L) + + shuffleRead.remoteBytesRead)) + + taskEnd.taskMetrics.shuffleWriteMetrics.foreach(shuffleWrite => + executorToShuffleWrite.put(eid, executorToShuffleWrite.getOrElse(eid, 0L) + + shuffleWrite.shuffleBytesWritten)) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala index 75c0dd2c7f..3c53e88380 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala @@ -17,8 +17,9 @@ package org.apache.spark.ui.jobs -private[spark] class ExecutorSummary() { - var duration : Long = 0 +/** class for reporting aggregated metrics for each executors in stageUI */ +private[spark] class ExecutorSummary { + var taskTime : Long = 0 var failedTasks : Int = 0 var succeededTasks : Int = 0 var shuffleRead : Long = 0 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 763d5a344b..0e9dd4a8c7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -40,7 +40,7 @@ private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) - + @@ -61,7 +61,7 @@ private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) case (k,v) => { - + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala index 854afb665a..ca5a28625b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -56,10 +56,6 @@ private[spark] class IndexPage(parent: JobProgressUI) { {parent.formatDuration(now - listener.sc.startTime)}
  • Scheduling Mode: {parent.sc.getSchedulingMode}
  • -
  • - Executor Summary: - {listener.stageIdToExecutorSummaries.size} -
  • Active Stages: {activeStages.size} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 64ce715993..07a42f0503 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -144,23 +144,14 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList } // update duration - y.duration += taskEnd.taskInfo.duration - - // update shuffle read/write - if (null != taskEnd.taskMetrics) { - val shuffleRead = taskEnd.taskMetrics.shuffleReadMetrics - shuffleRead match { - case Some(s) => - y.shuffleRead += s.remoteBytesRead - case _ => {} - } - val shuffleWrite = taskEnd.taskMetrics.shuffleWriteMetrics - shuffleWrite match { - case Some(s) => { - y.shuffleWrite += s.shuffleBytesWritten - } - case _ => {} - } + y.taskTime += taskEnd.taskInfo.duration + + taskEnd.taskMetrics.shuffleReadMetrics.foreach { shuffleRead => + y.shuffleRead += shuffleRead.remoteBytesRead + } + + taskEnd.taskMetrics.shuffleWriteMetrics.foreach { shuffleWrite => + y.shuffleWrite += shuffleWrite.shuffleBytesWritten } } case _ => {} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index c077613b1d..d8a6c9e2dc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -66,7 +66,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
    • - Total duration across all tasks: + Total task time across all tasks: {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
    • {if (hasShuffleRead) @@ -163,9 +163,9 @@ private[spark] class StagePage(parent: JobProgressUI) { val executorTable = new ExecutorTable(parent, stageId) val content = summary ++ -

      Summary Metrics for Executors

      ++ executorTable.toNodeSeq() ++

      Summary Metrics for {numCompleted} Completed Tasks

      ++
      {summaryTable.getOrElse("No tasks have reported metrics yet.")}
      ++ +

      Aggregated Metrics by Executors

      ++ executorTable.toNodeSeq() ++

      Tasks

      ++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 9ad6de3c6d..463d85dfd5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr {if (isFairScheduler) {
  • } else {}} - + -- cgit v1.2.3 From b7bfae1afecad0ae79d5d040d2e02e390c272efb Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Sun, 22 Dec 2013 07:27:28 -0800 Subject: Correctly merged in maxTaskFailures fix --- core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- core/src/test/scala/org/apache/spark/FailureSuite.scala | 4 ++-- .../scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala | 2 +- .../test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 0ac982909c..aa3fb0b35a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -540,7 +540,7 @@ private[spark] class TaskSetManager( if (numFailures(index) >= maxTaskFailures) { logError("Task %s:%d failed %d times; aborting job".format( taskSet.id, index, maxTaskFailures)) - abort("Task %s:%d failed more than %d times (most recent failure: %s)".format( + abort("Task %s:%d failed %d times (most recent failure: %s)".format( taskSet.id, index, maxTaskFailures, failureReason)) } } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index af448fcb37..befdc1589f 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -42,7 +42,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. test("failure in a single-stage job") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,2]", "test") val results = sc.makeRDD(1 to 3, 3).map { x => FailureSuiteState.synchronized { FailureSuiteState.tasksRun += 1 @@ -62,7 +62,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { // Run a map-reduce job in which a reduce task deterministically fails once. 
test("failure in a two-stage job") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,2]", "test") val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { case (k, v) => FailureSuiteState.synchronized { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 9deed568ac..f28d5c7b13 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -53,7 +53,7 @@ class SparkContextSchedulerCreationSuite test("local-n") { val sched = createTaskScheduler("local[5]") - assert(sched.maxTaskFailures === 0) + assert(sched.maxTaskFailures === 1) sched.backend match { case s: LocalBackend => assert(s.totalCores === 5) case _ => fail() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 2265619570..ca97f7d2a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -89,7 +89,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA test("task retried if result missing from block manager") { // Set the maximum number of task failures to > 0, so that the task set isn't aborted // after the result is missing. - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,2]", "test") // If this test hangs, it's probably because no resource offers were made after the task // failed. val scheduler: TaskSchedulerImpl = sc.taskScheduler match { -- cgit v1.2.3 From cbb28111896844a0fd94346cd9c6f9926c706555 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Sun, 22 Dec 2013 15:03:58 -0500 Subject: Release JVM reference to the ALSModel when done. 
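The pattern, sketched as a stand-alone (hypothetical) wrapper for clarity; the actual change is the two-line __del__ fix below. Any Py4J handle held from Python pins the corresponding JVM object, so the wrapper detaches it when Python garbage-collects the wrapper:

    class JavaModelWrapper(object):
        def __init__(self, sc, java_model):
            self._context = sc              # SparkContext, which owns the Py4J gateway
            self._java_model = java_model   # Py4J handle to the JVM-side model

        def __del__(self):
            # Without the detach, the gateway keeps the JVM-side object reachable
            # for as long as the JVM runs, even after the Python wrapper is gone.
            self._context._gateway.detach(self._java_model)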
--- python/pyspark/mllib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 22187eb4dd..1f5a5f6c01 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -357,8 +357,8 @@ class ALSModel(object): self._context = sc self._java_model = java_model - #def __del__(self): - #self._gateway.detach(self._java_model) + def __del__(self): + self._context._gateway.detach(self._java_model) def predict(self, user, product): return self._java_model.predict(user, product) -- cgit v1.2.3 From 2f689ba97b437092bf52063cface12aa9ee09bf3 Mon Sep 17 00:00:00 2001 From: "wangda.tan" Date: Mon, 23 Dec 2013 15:03:45 +0800 Subject: SPARK-968, added executor address showing in aggregated metrics by executors table --- .../main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 0e9dd4a8c7..0dd876480a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -21,6 +21,7 @@ import scala.xml.Node import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.util.Utils +import scala.collection.mutable /** Page showing executor summary */ private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) { @@ -40,6 +41,7 @@ private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int)
    Executor IDDurationTask Time Total Tasks Failed Tasks Succeeded Tasks
    {k}{parent.formatDuration(v.duration)}{parent.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} Pool NameDescription SubmittedDurationTask Time Tasks: Succeeded/Total Shuffle Read Shuffle Write
    + @@ -54,6 +56,16 @@ private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) } private def createExecutorTable() : Seq[Node] = { + // make a executor-id -> address map + val executorIdToAddress = mutable.HashMap[String, String]() + val storageStatusList = parent.sc.getExecutorStorageStatus + for (statusId <- 0 until storageStatusList.size) { + val blockManagerId = parent.sc.getExecutorStorageStatus(statusId).blockManagerId + val address = blockManagerId.hostPort + val executorId = blockManagerId.executorId + executorIdToAddress.put(executorId, address) + } + val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId) executorIdToSummary match { case Some(x) => { @@ -61,6 +73,7 @@ private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) case (k,v) => { + -- cgit v1.2.3 From dc3ee6b6122229cd99a133baf10a46dac2f7e9e2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 23 Dec 2013 11:30:42 -0800 Subject: Added comments to BatchInfo and JobSet, based on Patrick's comment on PR 277. --- .../apache/spark/streaming/scheduler/BatchInfo.scala | 19 +++++++++++++++++++ .../org/apache/spark/streaming/scheduler/JobSet.scala | 10 +++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 88e4af59b7..e3fb07624e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -21,6 +21,12 @@ import org.apache.spark.streaming.Time /** * Class having information on completed batches. + * @param batchTime Time of the batch + * @param submissionTime Clock time of when jobs of this batch was submitted to + * the streaming scheduler queue + * @param processingStartTime Clock time of when the first job of this batch started processing + * @param processingEndTime Clock time of when the last job of this batch finished processing + * */ case class BatchInfo( batchTime: Time, @@ -29,9 +35,22 @@ case class BatchInfo( processingEndTime: Option[Long] ) { + /** + * Time taken for the first job of this batch to start processing from the time this batch + * was submitted to the streaming scheduler. Essentially, it is + * `processingStartTime` - `submissionTime`. + */ def schedulingDelay = processingStartTime.map(_ - submissionTime) + /** + * Time taken for the all jobs of this batch to finish processing from the time they started + * processing. Essentially, it is `processingEndTime` - `processingStartTime`. + */ def processingDelay = processingEndTime.zip(processingStartTime).map(x => x._1 - x._2).headOption + /** + * Time taken for all the jobs of this batch to finish processing from the time they + * were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. 
+ */ def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index cf7431a8a3..57268674ea 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -27,9 +27,9 @@ private[streaming] case class JobSet(time: Time, jobs: Seq[Job]) { private val incompleteJobs = new HashSet[Job]() - var submissionTime = System.currentTimeMillis() - var processingStartTime = -1L - var processingEndTime = -1L + var submissionTime = System.currentTimeMillis() // when this jobset was submitted + var processingStartTime = -1L // when the first job of this jobset started processing + var processingEndTime = -1L // when the last job of this jobset finished processing jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) } incompleteJobs ++= jobs @@ -47,8 +47,12 @@ case class JobSet(time: Time, jobs: Seq[Job]) { def hasCompleted() = incompleteJobs.isEmpty + // Time taken to process all the jobs from the time they started processing + // (i.e. not including the time they wait in the streaming scheduler queue) def processingDelay = processingEndTime - processingStartTime + // Time taken to process all the jobs from the time they were submitted + // (i.e. including the time they wait in the streaming scheduler queue) def totalDelay = { processingEndTime - time.milliseconds } -- cgit v1.2.3 From f9771690a698b6ce5d29eb36b38bbeb498d1af0d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 23 Dec 2013 11:32:26 -0800 Subject: Minor formatting fixes. --- .../scala/org/apache/spark/streaming/scheduler/BatchInfo.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index e3fb07624e..4e8d07fe92 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -26,7 +26,6 @@ import org.apache.spark.streaming.Time * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing * @param processingEndTime Clock time of when the last job of this batch finished processing - * */ case class BatchInfo( batchTime: Time, @@ -48,9 +47,9 @@ case class BatchInfo( */ def processingDelay = processingEndTime.zip(processingStartTime).map(x => x._1 - x._2).headOption - /** - * Time taken for all the jobs of this batch to finish processing from the time they - * were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. - */ + /** + * Time taken for all the jobs of this batch to finish processing from the time they + * were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + */ def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption } -- cgit v1.2.3 From 6eaa0505493511adb040257abc749fcd774bbb68 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 23 Dec 2013 15:55:45 -0800 Subject: Minor change for PR 277. 
--- .../test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 16410a21e3..fa64142096 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.streaming.scheduler._ import scala.collection.mutable.ArrayBuffer import org.scalatest.matchers.ShouldMatchers -class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers{ +class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) -- cgit v1.2.3 From fc80b2e693d4c52d0f1ada67216723902c09c666 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 23 Dec 2013 21:20:20 -0800 Subject: Show full stack trace and time taken in unit tests. --- project/SparkBuild.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ab96cfa18b..7bcbd90bd3 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -114,6 +114,9 @@ object SparkBuild extends Build { fork := true, javaOptions += "-Xmx3g", + // Show full stack trace and duration in test cases. + testOptions in Test += Tests.Argument("-oDF"), + // Only allow one test at a time, even across projects, since they run in the same JVM concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), @@ -259,7 +262,7 @@ object SparkBuild extends Build { libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v ) ) - + def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", libraryDependencies ++= Seq( -- cgit v1.2.3 From a8bb86389d8dc8efeff83561aea044a3c4924df5 Mon Sep 17 00:00:00 2001 From: azuryyu Date: Tue, 24 Dec 2013 16:52:20 +0800 Subject: Fixed job name in the java streaming example. 
--- .../java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java index 9a8e4209ed..22994fb2ec 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java @@ -53,7 +53,7 @@ public class JavaKafkaWordCount { } // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "KafkaWordCount", new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); int numThreads = Integer.parseInt(args[4]); -- cgit v1.2.3 From 66b7bea7f82efa4f52186d15824d31035be253de Mon Sep 17 00:00:00 2001 From: azuryyu Date: Tue, 24 Dec 2013 18:16:49 +0800 Subject: Make App report interval configurable during 'run on Yarn' --- new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 94678815e8..9fdee29498 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -437,8 +437,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } def monitorApplication(appId: ApplicationId): Boolean = { + val interval = System.getProperty("spark.yarn.report.interval", "1000").toLong + while (true) { - Thread.sleep(1000) + Thread.sleep(interval) val report = super.getApplicationReport(appId) logInfo("Application report from ASM: \n" + -- cgit v1.2.3 From 2402180b32d530319d0526490afa3cfafc5c36b8 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Tue, 24 Dec 2013 16:18:33 -0500 Subject: Fix error message ugliness. 
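The strings being fixed come from the sanity checks on the Python-to-Java serialization format. For context, an illustrative encoder that satisfies the double-vector check (magic 1, then the element count, then the doubles, all native-endian, 16 + 8 * length bytes in total); the real pyspark code builds the same bytes with numpy rather than struct:

    import struct

    def serialize_double_vector(xs):
        header = struct.pack("=qq", 1, len(xs))    # 8-byte magic, 8-byte length
        body = struct.pack("=%dd" % len(xs), *xs)  # len(xs) doubles, 8 bytes each
        return bytearray(header + body)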
--- mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index 4620cab175..67ec974734 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -42,7 +42,7 @@ class PythonMLLibAPI extends Serializable { } val length = bb.getLong() if (packetLength != 16 + 8 * length) { - throw new IllegalArgumentException("Length " + length + "is wrong.") + throw new IllegalArgumentException("Length " + length + " is wrong.") } val db = bb.asDoubleBuffer() val ans = new Array[Double](length.toInt) @@ -76,7 +76,7 @@ class PythonMLLibAPI extends Serializable { val rows = bb.getLong() val cols = bb.getLong() if (packetLength != 24 + 8 * rows * cols) { - throw new IllegalArgumentException("Size " + rows + "x" + cols + "is wrong.") + throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") } val db = bb.asDoubleBuffer() val ans = new Array[Array[Double]](rows.toInt) -- cgit v1.2.3 From 58e2a7d6d4f036b20896674b1cac076d8daa55e8 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Tue, 24 Dec 2013 16:48:40 -0500 Subject: Move PythonMLLibAPI into its own package. --- .../apache/spark/mllib/api/PythonMLLibAPI.scala | 231 -------------------- .../spark/mllib/api/python/PythonMLLibAPI.scala | 232 +++++++++++++++++++++ 2 files changed, 232 insertions(+), 231 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala deleted file mode 100644 index 67ec974734..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.classification._ -import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.recommendation._ -import org.apache.spark.rdd.RDD -import java.nio.ByteBuffer -import java.nio.ByteOrder -import java.nio.DoubleBuffer - -/** - * The Java stubs necessary for the Python mllib bindings. 
- */ -class PythonMLLibAPI extends Serializable { - private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { - val packetLength = bytes.length - if (packetLength < 16) { - throw new IllegalArgumentException("Byte array too short.") - } - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.getLong() - if (magic != 1) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val length = bb.getLong() - if (packetLength != 16 + 8 * length) { - throw new IllegalArgumentException("Length " + length + " is wrong.") - } - val db = bb.asDoubleBuffer() - val ans = new Array[Double](length.toInt) - db.get(ans) - return ans - } - - private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { - val len = doubles.length - val bytes = new Array[Byte](16 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putLong(1) - bb.putLong(len) - val db = bb.asDoubleBuffer() - db.put(doubles) - return bytes - } - - private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { - val packetLength = bytes.length - if (packetLength < 24) { - throw new IllegalArgumentException("Byte array too short.") - } - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.getLong() - if (magic != 2) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val rows = bb.getLong() - val cols = bb.getLong() - if (packetLength != 24 + 8 * rows * cols) { - throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") - } - val db = bb.asDoubleBuffer() - val ans = new Array[Array[Double]](rows.toInt) - var i = 0 - for (i <- 0 until rows.toInt) { - ans(i) = new Array[Double](cols.toInt) - db.get(ans(i)) - } - return ans - } - - private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { - val rows = doubles.length - var cols = 0 - if (rows > 0) { - cols = doubles(0).length - } - val bytes = new Array[Byte](24 + 8 * rows * cols) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putLong(2) - bb.putLong(rows) - bb.putLong(cols) - val db = bb.asDoubleBuffer() - var i = 0 - for (i <- 0 until rows) { - db.put(doubles(i)) - } - return bytes - } - - private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, - dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): - java.util.LinkedList[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(xBytes => { - val x = deserializeDoubleVector(xBytes) - LabeledPoint(x(0), x.slice(1, x.length)) - }) - val initialWeights = deserializeDoubleVector(initialWeightsBA) - val model = trainFunc(data, initialWeights) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.weights)) - ret.add(model.intercept: java.lang.Double) - return ret - } - - /** - * Java stub for Python mllib LinearRegressionModel.train() - */ - def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], - numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => - LinearRegressionWithSGD.train(data, numIterations, stepSize, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) - } - - /** - * Java stub for Python mllib LassoModel.train() - */ - def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: 
Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => - LassoWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) - } - - /** - * Java stub for Python mllib RidgeRegressionModel.train() - */ - def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => - RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) - } - - /** - * Java stub for Python mllib SVMModel.train() - */ - def trainSVMModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => - SVMWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) - } - - /** - * Java stub for Python mllib LogisticRegressionModel.train() - */ - def trainLogisticRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], - numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => - LogisticRegressionWithSGD.train(data, numIterations, stepSize, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) - } - - /** - * Java stub for Python mllib KMeansModel.train() - */ - def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, - maxIterations: Int, runs: Int, initializationMode: String): - java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) - val model = KMeans.train(data, k, maxIterations, runs, initializationMode) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleMatrix(model.clusterCenters)) - return ret - } - - private def unpackRating(ratingBytes: Array[Byte]): Rating = { - val bb = ByteBuffer.wrap(ratingBytes) - bb.order(ByteOrder.nativeOrder()) - val user = bb.getInt() - val product = bb.getInt() - val rating = bb.getDouble() - return new Rating(user, product, rating) - } - - /** - * Java stub for Python mllib ALSModel.train(). This stub returns a handle - * to the Java object instead of the content of the Java object. Extra care - * needs to be taken in the Python code to ensure it gets freed on exit; see - * the Py4J documentation. - */ - def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, - iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) - return ALS.train(ratings, rank, iterations, lambda, blocks) - } - - /** - * Java stub for Python mllib ALSModel.trainImplicit(). This stub returns a - * handle to the Java object instead of the content of the Java object. - * Extra care needs to be taken in the Python code to ensure it gets freed on - * exit; see the Py4J documentation. 
- */ - def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, - iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) - return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala new file mode 100644 index 0000000000..ca474322a8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.classification._ +import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.recommendation._ +import org.apache.spark.rdd.RDD +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.DoubleBuffer + +/** + * The Java stubs necessary for the Python mllib bindings. 
+ */ +class PythonMLLibAPI extends Serializable { + private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { + val packetLength = bytes.length + if (packetLength < 16) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.getLong() + if (magic != 1) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val length = bb.getLong() + if (packetLength != 16 + 8 * length) { + throw new IllegalArgumentException("Length " + length + " is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Double](length.toInt) + db.get(ans) + return ans + } + + private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { + val len = doubles.length + val bytes = new Array[Byte](16 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putLong(1) + bb.putLong(len) + val db = bb.asDoubleBuffer() + db.put(doubles) + return bytes + } + + private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { + val packetLength = bytes.length + if (packetLength < 24) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.getLong() + if (magic != 2) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val rows = bb.getLong() + val cols = bb.getLong() + if (packetLength != 24 + 8 * rows * cols) { + throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Array[Double]](rows.toInt) + var i = 0 + for (i <- 0 until rows.toInt) { + ans(i) = new Array[Double](cols.toInt) + db.get(ans(i)) + } + return ans + } + + private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { + val rows = doubles.length + var cols = 0 + if (rows > 0) { + cols = doubles(0).length + } + val bytes = new Array[Byte](24 + 8 * rows * cols) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putLong(2) + bb.putLong(rows) + bb.putLong(cols) + val db = bb.asDoubleBuffer() + var i = 0 + for (i <- 0 until rows) { + db.put(doubles(i)) + } + return bytes + } + + private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, + dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): + java.util.LinkedList[java.lang.Object] = { + val data = dataBytesJRDD.rdd.map(xBytes => { + val x = deserializeDoubleVector(xBytes) + LabeledPoint(x(0), x.slice(1, x.length)) + }) + val initialWeights = deserializeDoubleVector(initialWeightsBA) + val model = trainFunc(data, initialWeights) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleVector(model.weights)) + ret.add(model.intercept: java.lang.Double) + return ret + } + + /** + * Java stub for Python mllib LinearRegressionModel.train() + */ + def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, stepSize: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LinearRegressionWithSGD.train(data, numIterations, stepSize, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib LassoModel.train() + */ + def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: 
Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LassoWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib RidgeRegressionModel.train() + */ + def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib SVMModel.train() + */ + def trainSVMModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + SVMWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib LogisticRegressionModel.train() + */ + def trainLogisticRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, stepSize: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LogisticRegressionWithSGD.train(data, numIterations, stepSize, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib KMeansModel.train() + */ + def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, + maxIterations: Int, runs: Int, initializationMode: String): + java.util.List[java.lang.Object] = { + val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + val model = KMeans.train(data, k, maxIterations, runs, initializationMode) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleMatrix(model.clusterCenters)) + return ret + } + + private def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + return new Rating(user, product, rating) + } + + /** + * Java stub for Python mllib ALSModel.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.train(ratings, rank, iterations, lambda, blocks) + } + + /** + * Java stub for Python mllib ALSModel.trainImplicit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. 
+ */ + def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) + } +} -- cgit v1.2.3 From 4efec6eb941c1c6cdec884174ea98c040a277cde Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Tue, 24 Dec 2013 16:49:03 -0500 Subject: Python change for move of PythonMLLibAPI. --- python/pyspark/java_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 2941984e19..eb79135b9d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -62,6 +62,6 @@ def launch_gateway(): # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") - java_import(gateway.jvm, "org.apache.spark.mllib.api.*") + java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") java_import(gateway.jvm, "scala.Tuple2") return gateway -- cgit v1.2.3 From 86e38c49420098da422a17e7c098efa34c94c35b Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Tue, 24 Dec 2013 16:49:31 -0500 Subject: Remove useless line from test stub. --- python/pyspark/mllib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 1f5a5f6c01..46f368b1ec 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -384,7 +384,6 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() - print failure_count,"failures among",test_count,"tests" if failure_count: exit(-1) -- cgit v1.2.3 From 1efe3adf560d207f9106ffd4e15934e422adb636 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 24 Dec 2013 14:18:39 -0800 Subject: Responded to Reynold's style comments --- .../main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 7 ++++--- .../src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 4 ++-- .../main/scala/org/apache/spark/scheduler/local/LocalBackend.scala | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 7409168f7b..dbac6b96ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -46,9 +46,10 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * we are holding a lock on ourselves. 
*/ private[spark] class TaskSchedulerImpl( - val sc: SparkContext, - val maxTaskFailures : Int = System.getProperty("spark.task.maxFailures", "4").toInt, - isLocal: Boolean = false) extends TaskScheduler with Logging { + val sc: SparkContext, + val maxTaskFailures : Int = System.getProperty("spark.task.maxFailures", "4").toInt, + isLocal: Boolean = false) + extends TaskScheduler with Logging { // How often to check for speculative tasks val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index aa3fb0b35a..c676e73e03 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.NotSerializableException import java.util.Arrays import scala.collection.mutable.ArrayBuffer @@ -28,8 +29,7 @@ import scala.math.min import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.util.{SystemClock, Clock} -import java.io.NotSerializableException +import org.apache.spark.util.{Clock, SystemClock} /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 69c1c04843..4edc6a0d3f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -93,7 +93,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: } override def reviveOffers() { - localActor ! ReviveOffers + localActor ! ReviveOffers } override def defaultParallelism() = totalCores -- cgit v1.2.3 From 3665c722b5b540ec96463031b9980cdc43829deb Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Tue, 24 Dec 2013 17:25:04 -0800 Subject: Typo: avaiable -> available --- python/pyspark/shell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index a475959090..ef07eb437b 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -42,7 +42,7 @@ print "Using Python version %s (%s, %s)" % ( platform.python_version(), platform.python_build()[0], platform.python_build()[1]) -print "Spark context avaiable as sc." +print "Spark context available as sc." if add_files != None: print "Adding files: [%s]" % ", ".join(add_files) -- cgit v1.2.3 From 05163057a1810f0a32b722e8c93e5435240636d9 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Wed, 25 Dec 2013 00:08:05 -0500 Subject: Split the mllib bindings into a whole bunch of modules and rename some things. 
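The regression, classification, clustering and recommendation code now lives in
separate pyspark.mllib submodules, and training moves from the model classes onto
LinearRegressionWithSGD / LassoWithSGD / RidgeRegressionWithSGD /
LogisticRegressionWithSGD / SVMWithSGD / KMeans / ALS. A rough sketch of the
resulting usage, lifted from the doctests added in this patch (illustrative only;
assumes sc is an existing SparkContext):

    from numpy import array
    from pyspark.mllib.regression import LinearRegressionWithSGD
    from pyspark.mllib.classification import SVMWithSGD

    data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4, 2)
    lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data),
                                        initial_weights=array([1.0]))
    svm = SVMWithSGD.train(sc, sc.parallelize(
        array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4, 2)))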
--- python/pyspark/__init__.py | 7 +- python/pyspark/mllib.py | 391 --------------------------------- python/pyspark/mllib/__init__.py | 46 ++++ python/pyspark/mllib/_common.py | 227 +++++++++++++++++++ python/pyspark/mllib/classification.py | 86 ++++++++ python/pyspark/mllib/clustering.py | 79 +++++++ python/pyspark/mllib/recommendation.py | 74 +++++++ python/pyspark/mllib/regression.py | 110 ++++++++++ 8 files changed, 623 insertions(+), 397 deletions(-) delete mode 100644 python/pyspark/mllib.py create mode 100644 python/pyspark/mllib/__init__.py create mode 100644 python/pyspark/mllib/_common.py create mode 100644 python/pyspark/mllib/classification.py create mode 100644 python/pyspark/mllib/clustering.py create mode 100644 python/pyspark/mllib/recommendation.py create mode 100644 python/pyspark/mllib/regression.py diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 3d73d95909..1f35f6f939 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -42,11 +42,6 @@ from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel -from pyspark.mllib import LinearRegressionModel, LassoModel, \ - RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel, \ - ALSModel -__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", - "LinearRegressionModel", "LassoModel", "RidgeRegressionModel", - "LogisticRegressionModel", "SVMModel", "KMeansModel", "ALSModel"]; +__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel"] diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py deleted file mode 100644 index 46f368b1ec..0000000000 --- a/python/pyspark/mllib.py +++ /dev/null @@ -1,391 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from numpy import * -from pyspark import SparkContext - -# Double vector format: -# -# [8-byte 1] [8-byte length] [length*8 bytes of data] -# -# Double matrix format: -# -# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data] -# -# This is all in machine-endian. That means that the Java interpreter and the -# Python interpreter must agree on what endian the machine is. - -def _deserialize_byte_array(shape, ba, offset): - """Wrapper around ndarray aliasing hack. 
- - >>> x = array([1.0, 2.0, 3.0, 4.0, 5.0]) - >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0)) - True - >>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2) - >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0)) - True - """ - ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", - order='C') - return ar.copy() - -def _serialize_double_vector(v): - """Serialize a double vector into a mutually understood format.""" - if type(v) != ndarray: - raise TypeError("_serialize_double_vector called on a %s; " - "wanted ndarray" % type(v)) - if v.dtype != float64: - raise TypeError("_serialize_double_vector called on an ndarray of %s; " - "wanted ndarray of float64" % v.dtype) - if v.ndim != 1: - raise TypeError("_serialize_double_vector called on a %ddarray; " - "wanted a 1darray" % v.ndim) - length = v.shape[0] - ba = bytearray(16 + 8*length) - header = ndarray(shape=[2], buffer=ba, dtype="int64") - header[0] = 1 - header[1] = length - copyto(ndarray(shape=[length], buffer=ba, offset=16, - dtype="float64"), v) - return ba - -def _deserialize_double_vector(ba): - """Deserialize a double vector from a mutually understood format. - - >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]) - >>> array_equal(x, _deserialize_double_vector(_serialize_double_vector(x))) - True - """ - if type(ba) != bytearray: - raise TypeError("_deserialize_double_vector called on a %s; " - "wanted bytearray" % type(ba)) - if len(ba) < 16: - raise TypeError("_deserialize_double_vector called on a %d-byte array, " - "which is too short" % len(ba)) - if (len(ba) & 7) != 0: - raise TypeError("_deserialize_double_vector called on a %d-byte array, " - "which is not a multiple of 8" % len(ba)) - header = ndarray(shape=[2], buffer=ba, dtype="int64") - if header[0] != 1: - raise TypeError("_deserialize_double_vector called on bytearray " - "with wrong magic") - length = header[1] - if len(ba) != 8*length + 16: - raise TypeError("_deserialize_double_vector called on bytearray " - "with wrong length") - return _deserialize_byte_array([length], ba, 16) - -def _serialize_double_matrix(m): - """Serialize a double matrix into a mutually understood format.""" - if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2): - rows = m.shape[0] - cols = m.shape[1] - ba = bytearray(24 + 8 * rows * cols) - header = ndarray(shape=[3], buffer=ba, dtype="int64") - header[0] = 2 - header[1] = rows - header[2] = cols - copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24, - dtype="float64", order='C'), m) - return ba - else: - raise TypeError("_serialize_double_matrix called on a " - "non-double-matrix") - -def _deserialize_double_matrix(ba): - """Deserialize a double matrix from a mutually understood format.""" - if type(ba) != bytearray: - raise TypeError("_deserialize_double_matrix called on a %s; " - "wanted bytearray" % type(ba)) - if len(ba) < 24: - raise TypeError("_deserialize_double_matrix called on a %d-byte array, " - "which is too short" % len(ba)) - if (len(ba) & 7) != 0: - raise TypeError("_deserialize_double_matrix called on a %d-byte array, " - "which is not a multiple of 8" % len(ba)) - header = ndarray(shape=[3], buffer=ba, dtype="int64") - if (header[0] != 2): - raise TypeError("_deserialize_double_matrix called on bytearray " - "with wrong magic") - rows = header[1] - cols = header[2] - if (len(ba) != 8*rows*cols + 24): - raise TypeError("_deserialize_double_matrix called on bytearray " - "with wrong length") - return _deserialize_byte_array([rows, cols], ba, 24) - -def 
_linear_predictor_typecheck(x, coeffs): - """Check that x is a one-dimensional vector of the right shape. - This is a temporary hackaround until I actually implement bulk predict.""" - if type(x) == ndarray: - if x.ndim == 1: - if x.shape == coeffs.shape: - pass - else: - raise RuntimeError("Got array of %d elements; wanted %d" - % shape(x)[0] % shape(coeffs)[0]) - else: - raise RuntimeError("Bulk predict not yet supported.") - elif (type(x) == RDD): - raise RuntimeError("Bulk predict not yet supported.") - else: - raise TypeError("Argument of type " + type(x) + " unsupported") - -class LinearModel(object): - """Something that has a vector of coefficients and an intercept.""" - def __init__(self, coeff, intercept): - self._coeff = coeff - self._intercept = intercept - -class LinearRegressionModelBase(LinearModel): - """A linear regression model. - - >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) - >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6 - True - """ - def predict(self, x): - """Predict the value of the dependent variable given a vector x""" - """containing values for the independent variables.""" - _linear_predictor_typecheck(x, self._coeff) - return dot(self._coeff, x) + self._intercept - -def _get_unmangled_rdd(data, serializer): - dataBytes = data.map(serializer) - dataBytes._bypass_serializer = True - dataBytes.cache() - return dataBytes - -# Map a pickled Python RDD of numpy double vectors to a Java RDD of -# _serialized_double_vectors -def _get_unmangled_double_vector_rdd(data): - return _get_unmangled_rdd(data, _serialize_double_vector) - -# If we weren't given initial weights, take a zero vector of the appropriate -# length. -def _get_initial_weights(initial_weights, data): - if initial_weights is None: - initial_weights = data.first() - if type(initial_weights) != ndarray: - raise TypeError("At least one data element has type " - + type(initial_weights) + " which is not ndarray") - if initial_weights.ndim != 1: - raise TypeError("At least one data element has " - + initial_weights.ndim + " dimensions, which is not 1") - initial_weights = zeros([initial_weights.shape[0] - 1]) - return initial_weights - -# train_func should take two parameters, namely data and initial_weights, and -# return the result of a call to the appropriate JVM stub. -# _regression_train_wrapper is responsible for setup and error checking. -def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): - initial_weights = _get_initial_weights(initial_weights, data) - dataBytes = _get_unmangled_double_vector_rdd(data) - ans = train_func(dataBytes, _serialize_double_vector(initial_weights)) - if len(ans) != 2: - raise RuntimeError("JVM call result had unexpected length") - elif type(ans[0]) != bytearray: - raise RuntimeError("JVM call result had first element of type " - + type(ans[0]) + " which is not bytearray") - elif type(ans[1]) != float: - raise RuntimeError("JVM call result had second element of type " - + type(ans[0]) + " which is not float") - return klass(_deserialize_double_vector(ans[0]), ans[1]) - -class LinearRegressionModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit. 
- - >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) - >>> lrm = LinearRegressionModel.train(sc, sc.parallelize(data), initial_weights=array([1.0])) - """ - @classmethod - def train(cls, sc, data, iterations=100, step=1.0, - mini_batch_fraction=1.0, initial_weights=None): - """Train a linear regression model on the given data.""" - return _regression_train_wrapper(sc, lambda d, i: - sc._jvm.PythonMLLibAPI().trainLinearRegressionModel( - d._jrdd, iterations, step, mini_batch_fraction, i), - LinearRegressionModel, data, initial_weights) - -class LassoModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an - l_1 penalty term. - - >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) - >>> lrm = LassoModel.train(sc, sc.parallelize(data), initial_weights=array([1.0])) - """ - @classmethod - def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, - mini_batch_fraction=1.0, initial_weights=None): - """Train a Lasso regression model on the given data.""" - return _regression_train_wrapper(sc, lambda d, i: - sc._jvm.PythonMLLibAPI().trainLassoModel(d._jrdd, - iterations, step, reg_param, mini_batch_fraction, i), - LassoModel, data, initial_weights) - -class RidgeRegressionModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an - l_2 penalty term. - - >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) - >>> lrm = RidgeRegressionModel.train(sc, sc.parallelize(data), initial_weights=array([1.0])) - """ - @classmethod - def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, - mini_batch_fraction=1.0, initial_weights=None): - """Train a ridge regression model on the given data.""" - return _regression_train_wrapper(sc, lambda d, i: - sc._jvm.PythonMLLibAPI().trainRidgeModel(d._jrdd, - iterations, step, reg_param, mini_batch_fraction, i), - RidgeRegressionModel, data, initial_weights) - -class LogisticRegressionModel(LinearModel): - """A linear binary classification model derived from logistic regression. - - >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2) - >>> lrm = LogisticRegressionModel.train(sc, sc.parallelize(data)) - """ - def predict(self, x): - _linear_predictor_typecheck(x, _coeff) - margin = dot(x, _coeff) + intercept - prob = 1/(1 + exp(-margin)) - return 1 if prob > 0.5 else 0 - - @classmethod - def train(cls, sc, data, iterations=100, step=1.0, - mini_batch_fraction=1.0, initial_weights=None): - """Train a logistic regression model on the given data.""" - return _regression_train_wrapper(sc, lambda d, i: - sc._jvm.PythonMLLibAPI().trainLogisticRegressionModel(d._jrdd, - iterations, step, mini_batch_fraction, i), - LogisticRegressionModel, data, initial_weights) - -class SVMModel(LinearModel): - """A support vector machine. 
- - >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2) - >>> svm = SVMModel.train(sc, sc.parallelize(data)) - """ - def predict(self, x): - _linear_predictor_typecheck(x, _coeff) - margin = dot(x, _coeff) + intercept - return 1 if margin >= 0 else 0 - @classmethod - def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, - mini_batch_fraction=1.0, initial_weights=None): - """Train a support vector machine on the given data.""" - return _regression_train_wrapper(sc, lambda d, i: - sc._jvm.PythonMLLibAPI().trainSVMModel(d._jrdd, - iterations, step, reg_param, mini_batch_fraction, i), - SVMModel, data, initial_weights) - -class KMeansModel(object): - """A clustering model derived from the k-means method. - - >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) - >>> clusters = KMeansModel.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random") - >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0])) - True - >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0])) - True - >>> clusters = KMeansModel.train(sc, sc.parallelize(data), 2) - """ - def __init__(self, centers_): - self.centers = centers_ - - def predict(self, x): - """Find the cluster to which x belongs in this model.""" - best = 0 - best_distance = 1e75 - for i in range(0, self.centers.shape[0]): - diff = x - self.centers[i] - distance = sqrt(dot(diff, diff)) - if distance < best_distance: - best = i - best_distance = distance - return best - - @classmethod - def train(cls, sc, data, k, maxIterations=100, runs=1, - initialization_mode="k-means||"): - """Train a k-means clustering model.""" - dataBytes = _get_unmangled_double_vector_rdd(data) - ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd, - k, maxIterations, runs, initialization_mode) - if len(ans) != 1: - raise RuntimeError("JVM call result had unexpected length") - elif type(ans[0]) != bytearray: - raise RuntimeError("JVM call result had first element of type " - + type(ans[0]) + " which is not bytearray") - return KMeansModel(_deserialize_double_matrix(ans[0])) - -def _serialize_rating(r): - ba = bytearray(16) - intpart = ndarray(shape=[2], buffer=ba, dtype=int32) - doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8) - intpart[0], intpart[1], doublepart[0] = r - return ba - -class ALSModel(object): - """A matrix factorisation model trained by regularized alternating - least-squares. 
- - >>> r1 = (1, 1, 1.0) - >>> r2 = (1, 2, 2.0) - >>> r3 = (2, 1, 2.0) - >>> ratings = sc.parallelize([r1, r2, r3]) - >>> model = ALSModel.trainImplicit(sc, ratings, 1) - >>> model.predict(2,2) is not None - True - """ - - def __init__(self, sc, java_model): - self._context = sc - self._java_model = java_model - - def __del__(self): - self._context._gateway.detach(self._java_model) - - def predict(self, user, product): - return self._java_model.predict(user, product) - - @classmethod - def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): - ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) - mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd, - rank, iterations, lambda_, blocks) - return ALSModel(sc, mod) - - @classmethod - def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): - ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) - mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd, - rank, iterations, lambda_, blocks, alpha) - return ALSModel(sc, mod) - -def _test(): - import doctest - globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, - optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py new file mode 100644 index 0000000000..6037a3aa63 --- /dev/null +++ b/python/pyspark/mllib/__init__.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +PySpark is the Python API for Spark. + +Public classes: + + - L{SparkContext} + Main entry point for Spark functionality. + - L{RDD} + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + - L{Broadcast} + A broadcast variable that gets reused across tasks. + - L{Accumulator} + An "add-only" shared variable that tasks can only add values to. + - L{SparkFiles} + Access files shipped with jobs. + - L{StorageLevel} + Finer-grained cache persistence levels. 
+""" +import sys +import os +sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg")) + +from pyspark.mllib.regression import LinearRegressionModel, LassoModel, RidgeRegressionModel, LinearRegressionWithSGD, LassoWithSGD, RidgeRegressionWithSGD +from pyspark.mllib.classification import LogisticRegressionModel, SVMModel, LogisticRegressionWithSGD, SVMWithSGD +from pyspark.mllib.recommendation import MatrixFactorizationModel, ALS +from pyspark.mllib.clustering import KMeansModel, KMeans + + +__all__ = ["LinearRegressionModel", "LassoModel", "RidgeRegressionModel", "LinearRegressionWithSGD", "LassoWithSGD", "RidgeRegressionWithSGD", "LogisticRegressionModel", "SVMModel", "LogisticRegressionWithSGD", "SVMWithSGD", "MatrixFactorizationModel", "ALS", "KMeansModel", "KMeans"] diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py new file mode 100644 index 0000000000..e68bd8a9db --- /dev/null +++ b/python/pyspark/mllib/_common.py @@ -0,0 +1,227 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from numpy import ndarray, copyto, float64, int64, int32, zeros, array_equal, array, dot, shape +from pyspark import SparkContext + +# Double vector format: +# +# [8-byte 1] [8-byte length] [length*8 bytes of data] +# +# Double matrix format: +# +# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data] +# +# This is all in machine-endian. That means that the Java interpreter and the +# Python interpreter must agree on what endian the machine is. + +def _deserialize_byte_array(shape, ba, offset): + """Wrapper around ndarray aliasing hack. + + >>> x = array([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0)) + True + >>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2) + >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0)) + True + """ + ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", + order='C') + return ar.copy() + +def _serialize_double_vector(v): + """Serialize a double vector into a mutually understood format.""" + if type(v) != ndarray: + raise TypeError("_serialize_double_vector called on a %s; " + "wanted ndarray" % type(v)) + if v.dtype != float64: + raise TypeError("_serialize_double_vector called on an ndarray of %s; " + "wanted ndarray of float64" % v.dtype) + if v.ndim != 1: + raise TypeError("_serialize_double_vector called on a %ddarray; " + "wanted a 1darray" % v.ndim) + length = v.shape[0] + ba = bytearray(16 + 8*length) + header = ndarray(shape=[2], buffer=ba, dtype="int64") + header[0] = 1 + header[1] = length + copyto(ndarray(shape=[length], buffer=ba, offset=16, + dtype="float64"), v) + return ba + +def _deserialize_double_vector(ba): + """Deserialize a double vector from a mutually understood format. 
+ + >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]) + >>> array_equal(x, _deserialize_double_vector(_serialize_double_vector(x))) + True + """ + if type(ba) != bytearray: + raise TypeError("_deserialize_double_vector called on a %s; " + "wanted bytearray" % type(ba)) + if len(ba) < 16: + raise TypeError("_deserialize_double_vector called on a %d-byte array, " + "which is too short" % len(ba)) + if (len(ba) & 7) != 0: + raise TypeError("_deserialize_double_vector called on a %d-byte array, " + "which is not a multiple of 8" % len(ba)) + header = ndarray(shape=[2], buffer=ba, dtype="int64") + if header[0] != 1: + raise TypeError("_deserialize_double_vector called on bytearray " + "with wrong magic") + length = header[1] + if len(ba) != 8*length + 16: + raise TypeError("_deserialize_double_vector called on bytearray " + "with wrong length") + return _deserialize_byte_array([length], ba, 16) + +def _serialize_double_matrix(m): + """Serialize a double matrix into a mutually understood format.""" + if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2): + rows = m.shape[0] + cols = m.shape[1] + ba = bytearray(24 + 8 * rows * cols) + header = ndarray(shape=[3], buffer=ba, dtype="int64") + header[0] = 2 + header[1] = rows + header[2] = cols + copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24, + dtype="float64", order='C'), m) + return ba + else: + raise TypeError("_serialize_double_matrix called on a " + "non-double-matrix") + +def _deserialize_double_matrix(ba): + """Deserialize a double matrix from a mutually understood format.""" + if type(ba) != bytearray: + raise TypeError("_deserialize_double_matrix called on a %s; " + "wanted bytearray" % type(ba)) + if len(ba) < 24: + raise TypeError("_deserialize_double_matrix called on a %d-byte array, " + "which is too short" % len(ba)) + if (len(ba) & 7) != 0: + raise TypeError("_deserialize_double_matrix called on a %d-byte array, " + "which is not a multiple of 8" % len(ba)) + header = ndarray(shape=[3], buffer=ba, dtype="int64") + if (header[0] != 2): + raise TypeError("_deserialize_double_matrix called on bytearray " + "with wrong magic") + rows = header[1] + cols = header[2] + if (len(ba) != 8*rows*cols + 24): + raise TypeError("_deserialize_double_matrix called on bytearray " + "with wrong length") + return _deserialize_byte_array([rows, cols], ba, 24) + +def _linear_predictor_typecheck(x, coeffs): + """Check that x is a one-dimensional vector of the right shape. 
+ This is a temporary hackaround until I actually implement bulk predict.""" + if type(x) == ndarray: + if x.ndim == 1: + if x.shape == coeffs.shape: + pass + else: + raise RuntimeError("Got array of %d elements; wanted %d" + % (shape(x)[0], shape(coeffs)[0])) + else: + raise RuntimeError("Bulk predict not yet supported.") + elif (type(x) == RDD): + raise RuntimeError("Bulk predict not yet supported.") + else: + raise TypeError("Argument of type " + type(x) + " unsupported") + +def _get_unmangled_rdd(data, serializer): + dataBytes = data.map(serializer) + dataBytes._bypass_serializer = True + dataBytes.cache() + return dataBytes + +# Map a pickled Python RDD of numpy double vectors to a Java RDD of +# _serialized_double_vectors +def _get_unmangled_double_vector_rdd(data): + return _get_unmangled_rdd(data, _serialize_double_vector) + +class LinearModel(object): + """Something that has a vector of coefficients and an intercept.""" + def __init__(self, coeff, intercept): + self._coeff = coeff + self._intercept = intercept + +class LinearRegressionModelBase(LinearModel): + """A linear regression model. + + >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) + >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6 + True + """ + def predict(self, x): + """Predict the value of the dependent variable given a vector x""" + """containing values for the independent variables.""" + _linear_predictor_typecheck(x, self._coeff) + return dot(self._coeff, x) + self._intercept + +# If we weren't given initial weights, take a zero vector of the appropriate +# length. +def _get_initial_weights(initial_weights, data): + if initial_weights is None: + initial_weights = data.first() + if type(initial_weights) != ndarray: + raise TypeError("At least one data element has type " + + type(initial_weights) + " which is not ndarray") + if initial_weights.ndim != 1: + raise TypeError("At least one data element has " + + initial_weights.ndim + " dimensions, which is not 1") + initial_weights = zeros([initial_weights.shape[0] - 1]) + return initial_weights + +# train_func should take two parameters, namely data and initial_weights, and +# return the result of a call to the appropriate JVM stub. +# _regression_train_wrapper is responsible for setup and error checking. 
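+#
+# For illustration, a typical caller (from pyspark.mllib.regression in this
+# patch) looks like:
+#
+#     return _regression_train_wrapper(sc, lambda d, i:
+#             sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd,
+#                 iterations, step, reg_param, mini_batch_fraction, i),
+#             LassoModel, data, initial_weights)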
+def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): + initial_weights = _get_initial_weights(initial_weights, data) + dataBytes = _get_unmangled_double_vector_rdd(data) + ans = train_func(dataBytes, _serialize_double_vector(initial_weights)) + if len(ans) != 2: + raise RuntimeError("JVM call result had unexpected length") + elif type(ans[0]) != bytearray: + raise RuntimeError("JVM call result had first element of type " + + type(ans[0]) + " which is not bytearray") + elif type(ans[1]) != float: + raise RuntimeError("JVM call result had second element of type " + + type(ans[0]) + " which is not float") + return klass(_deserialize_double_vector(ans[0]), ans[1]) + +def _serialize_rating(r): + ba = bytearray(16) + intpart = ndarray(shape=[2], buffer=ba, dtype=int32) + doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8) + intpart[0], intpart[1], doublepart[0] = r + return ba + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py new file mode 100644 index 0000000000..70de332d34 --- /dev/null +++ b/python/pyspark/mllib/classification.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from numpy import array, dot, shape +from pyspark import SparkContext +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ + LinearModel, _linear_predictor_typecheck +from math import exp, log + +class LogisticRegressionModel(LinearModel): + """A linear binary classification model derived from logistic regression. 
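+    The model predicts 1 when 1 / (1 + exp(-margin)) > 0.5, where
+    margin = dot(x, coeff) + intercept, and 0 otherwise (see predict below).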
+ + >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2) + >>> lrm = LogisticRegressionWithSGD.train(sc, sc.parallelize(data)) + >>> lrm.predict(array([1.0])) != None + True + """ + def predict(self, x): + _linear_predictor_typecheck(x, self._coeff) + margin = dot(x, self._coeff) + self._intercept + prob = 1/(1 + exp(-margin)) + return 1 if prob > 0.5 else 0 + +class LogisticRegressionWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a logistic regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(d._jrdd, + iterations, step, mini_batch_fraction, i), + LogisticRegressionModel, data, initial_weights) + +class SVMModel(LinearModel): + """A support vector machine. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2) + >>> svm = SVMWithSGD.train(sc, sc.parallelize(data)) + >>> svm.predict(array([1.0])) != None + True + """ + def predict(self, x): + _linear_predictor_typecheck(x, self._coeff) + margin = dot(x, self._coeff) + self._intercept + return 1 if margin >= 0 else 0 + +class SVMWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a support vector machine on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + SVMModel, data, initial_weights) + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py new file mode 100644 index 0000000000..8cf20e591a --- /dev/null +++ b/python/pyspark/mllib/clustering.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from numpy import array, dot +from math import sqrt +from pyspark import SparkContext +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper + +class KMeansModel(object): + """A clustering model derived from the k-means method. 
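+    predict(x) assigns x to the nearest of the stored cluster centers, by
+    Euclidean distance.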
+ + >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) + >>> clusters = KMeans.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random") + >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0])) + True + >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0])) + True + >>> clusters = KMeans.train(sc, sc.parallelize(data), 2) + """ + def __init__(self, centers_): + self.centers = centers_ + + def predict(self, x): + """Find the cluster to which x belongs in this model.""" + best = 0 + best_distance = 1e75 + for i in range(0, self.centers.shape[0]): + diff = x - self.centers[i] + distance = sqrt(dot(diff, diff)) + if distance < best_distance: + best = i + best_distance = distance + return best + +class KMeans(object): + @classmethod + def train(cls, sc, data, k, maxIterations=100, runs=1, + initialization_mode="k-means||"): + """Train a k-means clustering model.""" + dataBytes = _get_unmangled_double_vector_rdd(data) + ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd, + k, maxIterations, runs, initialization_mode) + if len(ans) != 1: + raise RuntimeError("JVM call result had unexpected length") + elif type(ans[0]) != bytearray: + raise RuntimeError("JVM call result had first element of type " + + type(ans[0]) + " which is not bytearray") + return KMeansModel(_deserialize_double_matrix(ans[0])) + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py new file mode 100644 index 0000000000..14d06cba21 --- /dev/null +++ b/python/pyspark/mllib/recommendation.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper + +class MatrixFactorizationModel(object): + """A matrix factorisation model trained by regularized alternating + least-squares. 
+ + >>> r1 = (1, 1, 1.0) + >>> r2 = (1, 2, 2.0) + >>> r3 = (2, 1, 2.0) + >>> ratings = sc.parallelize([r1, r2, r3]) + >>> model = ALS.trainImplicit(sc, ratings, 1) + >>> model.predict(2,2) is not None + True + """ + + def __init__(self, sc, java_model): + self._context = sc + self._java_model = java_model + + def __del__(self): + self._context._gateway.detach(self._java_model) + + def predict(self, user, product): + return self._java_model.predict(user, product) + +class ALS(object): + @classmethod + def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): + ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd, + rank, iterations, lambda_, blocks) + return MatrixFactorizationModel(sc, mod) + + @classmethod + def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): + ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd, + rank, iterations, lambda_, blocks, alpha) + return MatrixFactorizationModel(sc, mod) + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py new file mode 100644 index 0000000000..a3a68b29e0 --- /dev/null +++ b/python/pyspark/mllib/regression.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from numpy import array, dot +from pyspark import SparkContext +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ + _linear_predictor_typecheck + +class LinearModel(object): + """Something that has a vector of coefficients and an intercept.""" + def __init__(self, coeff, intercept): + self._coeff = coeff + self._intercept = intercept + +class LinearRegressionModelBase(LinearModel): + """A linear regression model. 
+ + >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) + >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6 + True + """ + def predict(self, x): + """Predict the value of the dependent variable given a vector x""" + """containing values for the independent variables.""" + _linear_predictor_typecheck(x, self._coeff) + return dot(self._coeff, x) + self._intercept + +class LinearRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ + +class LinearRegressionWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a linear regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( + d._jrdd, iterations, step, mini_batch_fraction, i), + LinearRegressionModel, data, initial_weights) + +class LassoModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an + l_1 penalty term. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = LassoWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ + +class LassoWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a Lasso regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + LassoModel, data, initial_weights) + +class RidgeRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an + l_2 penalty term. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = RidgeRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ + +class RidgeRegressionWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a ridge regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + RidgeRegressionModel, data, initial_weights) + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() -- cgit v1.2.3 From 4e821390bca0d1f40b6f2f011bdc71476a1d3aa4 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Wed, 25 Dec 2013 00:09:00 -0500 Subject: Scala stubs for updated Python bindings. 
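The stub names now mirror the renamed Python classes that call them, e.g.
(mapping as of this patch together with the Python-side split):

    LogisticRegressionWithSGD.train -> PythonMLLibAPI.trainLogisticRegressionModelWithSGD
    LassoWithSGD.train              -> PythonMLLibAPI.trainLassoModelWithSGD
    SVMWithSGD.train                -> PythonMLLibAPI.trainSVMModelWithSGD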
--- .../spark/mllib/api/python/PythonMLLibAPI.scala | 26 +++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index ca474322a8..8247c1ebc5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -125,9 +125,9 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib LinearRegressionModel.train() + * Java stub for Python mllib LinearRegressionWithSGD.train() */ - def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], + def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { return trainRegressionModel((data, initialWeights) => @@ -137,9 +137,9 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib LassoModel.train() + * Java stub for Python mllib LassoWithSGD.train() */ - def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { return trainRegressionModel((data, initialWeights) => @@ -149,9 +149,9 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib RidgeRegressionModel.train() + * Java stub for Python mllib RidgeRegressionWithSGD.train() */ - def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { return trainRegressionModel((data, initialWeights) => @@ -161,9 +161,9 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib SVMModel.train() + * Java stub for Python mllib SVMWithSGD.train() */ - def trainSVMModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { return trainRegressionModel((data, initialWeights) => @@ -173,9 +173,9 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib LogisticRegressionModel.train() + * Java stub for Python mllib LogisticRegressionWithSGD.train() */ - def trainLogisticRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], + def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { return trainRegressionModel((data, initialWeights) => @@ -185,7 +185,7 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib KMeansModel.train() + * Java stub for Python mllib KMeans.train() */ def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, maxIterations: Int, runs: Int, initializationMode: String): @@ -207,7 +207,7 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python 
mllib ALSModel.train(). This stub returns a handle + * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care * needs to be taken in the Python code to ensure it gets freed on exit; see * the Py4J documentation. @@ -219,7 +219,7 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib ALSModel.trainImplicit(). This stub returns a + * Java stub for Python mllib ALS.trainImplicit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. -- cgit v1.2.3 From 02208a175c76be111eeb66dc19c7499a6656a067 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Wed, 25 Dec 2013 00:53:48 -0500 Subject: Initial weights in Scala are ones; do that too. Also fix some errors. --- python/pyspark/mllib/_common.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index e68bd8a9db..e74ba0fabc 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -15,7 +15,7 @@ # limitations under the License. # -from numpy import ndarray, copyto, float64, int64, int32, zeros, array_equal, array, dot, shape +from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape from pyspark import SparkContext # Double vector format: @@ -143,7 +143,7 @@ def _linear_predictor_typecheck(x, coeffs): elif (type(x) == RDD): raise RuntimeError("Bulk predict not yet supported.") else: - raise TypeError("Argument of type " + type(x) + " unsupported") + raise TypeError("Argument of type " + type(x).__name__ + " unsupported") def _get_unmangled_rdd(data, serializer): dataBytes = data.map(serializer) @@ -182,11 +182,11 @@ def _get_initial_weights(initial_weights, data): initial_weights = data.first() if type(initial_weights) != ndarray: raise TypeError("At least one data element has type " - + type(initial_weights) + " which is not ndarray") + + type(initial_weights).__name__ + " which is not ndarray") if initial_weights.ndim != 1: raise TypeError("At least one data element has " + initial_weights.ndim + " dimensions, which is not 1") - initial_weights = zeros([initial_weights.shape[0] - 1]) + initial_weights = ones([initial_weights.shape[0] - 1]) return initial_weights # train_func should take two parameters, namely data and initial_weights, and @@ -200,10 +200,10 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): raise RuntimeError("JVM call result had unexpected length") elif type(ans[0]) != bytearray: raise RuntimeError("JVM call result had first element of type " - + type(ans[0]) + " which is not bytearray") + + type(ans[0]).__name__ + " which is not bytearray") elif type(ans[1]) != float: raise RuntimeError("JVM call result had second element of type " - + type(ans[0]) + " which is not float") + + type(ans[0]).__name__ + " which is not float") return klass(_deserialize_double_vector(ans[0]), ans[1]) def _serialize_rating(r): -- cgit v1.2.3 From 5e71354cb7ff758d9a70494ca1788aebea1bbb08 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Wed, 25 Dec 2013 14:10:55 -0500 Subject: Fix copypasta in __init__.py. Don't import anything directly into pyspark.mllib. 
--- python/pyspark/mllib/__init__.py | 34 ++++++++-------------------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 6037a3aa63..e9c62f3410 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -16,31 +16,13 @@ # """ -PySpark is the Python API for Spark. - -Public classes: - - - L{SparkContext} - Main entry point for Spark functionality. - - L{RDD} - A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} - A broadcast variable that gets reused across tasks. - - L{Accumulator} - An "add-only" shared variable that tasks can only add values to. - - L{SparkFiles} - Access files shipped with jobs. - - L{StorageLevel} - Finer-grained cache persistence levels. +Python bindings for MLlib. """ -import sys -import os -sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg")) - -from pyspark.mllib.regression import LinearRegressionModel, LassoModel, RidgeRegressionModel, LinearRegressionWithSGD, LassoWithSGD, RidgeRegressionWithSGD -from pyspark.mllib.classification import LogisticRegressionModel, SVMModel, LogisticRegressionWithSGD, SVMWithSGD -from pyspark.mllib.recommendation import MatrixFactorizationModel, ALS -from pyspark.mllib.clustering import KMeansModel, KMeans - -__all__ = ["LinearRegressionModel", "LassoModel", "RidgeRegressionModel", "LinearRegressionWithSGD", "LassoWithSGD", "RidgeRegressionWithSGD", "LogisticRegressionModel", "SVMModel", "LogisticRegressionWithSGD", "SVMWithSGD", "MatrixFactorizationModel", "ALS", "KMeansModel", "KMeans"] +#from pyspark.mllib.regression import LinearRegressionModel, LassoModel, RidgeRegressionModel, LinearRegressionWithSGD, LassoWithSGD, RidgeRegressionWithSGD +#from pyspark.mllib.classification import LogisticRegressionModel, SVMModel, LogisticRegressionWithSGD, SVMWithSGD +#from pyspark.mllib.recommendation import MatrixFactorizationModel, ALS +#from pyspark.mllib.clustering import KMeansModel, KMeans +# +# +#__all__ = ["LinearRegressionModel", "LassoModel", "RidgeRegressionModel", "LinearRegressionWithSGD", "LassoWithSGD", "RidgeRegressionWithSGD", "LogisticRegressionModel", "SVMModel", "LogisticRegressionWithSGD", "SVMWithSGD", "MatrixFactorizationModel", "ALS", "KMeansModel", "KMeans"] -- cgit v1.2.3 From 9cbcf81453a9afca58645969c1bc3ff366392734 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Wed, 25 Dec 2013 14:12:42 -0500 Subject: Remove commented code in __init__.py. --- python/pyspark/mllib/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index e9c62f3410..b1a5df109b 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -18,11 +18,3 @@ """ Python bindings for MLlib. 
""" - -#from pyspark.mllib.regression import LinearRegressionModel, LassoModel, RidgeRegressionModel, LinearRegressionWithSGD, LassoWithSGD, RidgeRegressionWithSGD -#from pyspark.mllib.classification import LogisticRegressionModel, SVMModel, LogisticRegressionWithSGD, SVMWithSGD -#from pyspark.mllib.recommendation import MatrixFactorizationModel, ALS -#from pyspark.mllib.clustering import KMeansModel, KMeans -# -# -#__all__ = ["LinearRegressionModel", "LassoModel", "RidgeRegressionModel", "LinearRegressionWithSGD", "LassoWithSGD", "RidgeRegressionWithSGD", "LogisticRegressionModel", "SVMModel", "LogisticRegressionWithSGD", "SVMWithSGD", "MatrixFactorizationModel", "ALS", "KMeansModel", "KMeans"] -- cgit v1.2.3 From 14fcef72db765d0313d4ce3c986c08069a1a01ae Mon Sep 17 00:00:00 2001 From: liguoqiang Date: Thu, 26 Dec 2013 11:05:07 +0800 Subject: Renamed ClusterScheduler to TaskSchedulerImpl for yarn and new-yarn --- .../org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 9 +++++---- .../spark/scheduler/cluster/YarnClientClusterScheduler.scala | 3 ++- .../spark/scheduler/cluster/YarnClientSchedulerBackend.scala | 3 ++- .../apache/spark/scheduler/cluster/YarnClusterScheduler.scala | 3 ++- .../org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 9 +++++---- .../spark/scheduler/cluster/YarnClientClusterScheduler.scala | 3 ++- .../spark/scheduler/cluster/YarnClientSchedulerBackend.scala | 3 ++- .../apache/spark/scheduler/cluster/YarnClusterScheduler.scala | 4 ++-- 8 files changed, 22 insertions(+), 15 deletions(-) diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index c27257cda4..96a24cd2b1 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -28,7 +28,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import org.apache.spark.Logging import org.apache.spark.scheduler.SplitInfo -import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils import org.apache.hadoop.conf.Configuration @@ -233,9 +234,9 @@ private[yarn] class YarnAllocationHandler( // Note that the list we create below tries to ensure that not all containers end up within // a host if there is a sufficiently large number of hosts/containers. val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) - allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) - allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) - allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(offRackContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) // Run each of the allocated containers. 
for (container <- allocatedContainersToProcess) { diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 63a0449e5a..40307ab972 100644 --- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -21,12 +21,13 @@ import org.apache.spark._ import org.apache.hadoop.conf.Configuration import org.apache.spark.deploy.yarn.YarnAllocationHandler import org.apache.spark.util.Utils +import org.apache.spark.scheduler.TaskSchedulerImpl /** * * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM. */ -private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { +private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { def this(sc: SparkContext) = this(sc, new Configuration()) diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index b206780c78..350fc760a4 100644 --- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,9 +20,10 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 29b3f22e13..b318270f75 100644 --- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -21,12 +21,13 @@ import org.apache.spark._ import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} import org.apache.spark.util.Utils import org.apache.hadoop.conf.Configuration +import org.apache.spark.scheduler.TaskSchedulerImpl /** * * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done */ -private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { logInfo("Created YarnClusterScheduler") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 9ab2073529..eeee78f8ad 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -28,7 +28,8 @@ import 
scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import org.apache.spark.Logging import org.apache.spark.scheduler.SplitInfo -import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils import org.apache.hadoop.conf.Configuration @@ -214,9 +215,9 @@ private[yarn] class YarnAllocationHandler( // host if there are sufficiently large number of hosts/containers. val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers) + allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) + allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) + allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) // Run each of the allocated containers for (container <- allocatedContainers) { diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 63a0449e5a..522e0a9ad7 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -20,13 +20,14 @@ package org.apache.spark.scheduler.cluster import org.apache.spark._ import org.apache.hadoop.conf.Configuration import org.apache.spark.deploy.yarn.YarnAllocationHandler +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils /** * * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM. 
*/ -private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { +private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { def this(sc: SparkContext) = this(sc, new Configuration()) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index b206780c78..350fc760a4 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,9 +20,10 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 4e988b8017..2d9fbcb400 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark._ import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} -import org.apache.spark.scheduler.ClusterScheduler +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils /** @@ -30,7 +30,7 @@ import org.apache.spark.util.Utils * ApplicationMaster, etc. 
is done */ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) - extends ClusterScheduler(sc) { + extends TaskSchedulerImpl(sc) { logInfo("Created YarnClusterScheduler") -- cgit v1.2.3 From 2bd76f693d87330e8cda9c0a9568ee3addd8a422 Mon Sep 17 00:00:00 2001 From: liguoqiang Date: Thu, 26 Dec 2013 11:10:35 +0800 Subject: Renamed ClusterScheduler to TaskSchedulerImpl for yarn and new-yarn --- .../scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 3 +-- .../scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 96a24cd2b1..784a3112de 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -27,8 +27,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import org.apache.spark.Logging -import org.apache.spark.scheduler.SplitInfo -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index eeee78f8ad..a01657c9fa 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -27,8 +27,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import org.apache.spark.Logging -import org.apache.spark.scheduler.SplitInfo -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -- cgit v1.2.3 From c529dceaffad0d0eb611ecb70fb5723dab43d0a1 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 25 Dec 2013 12:16:21 -0800 Subject: Avoid a lump of coal (NPE) in JobProgressListener's stocking. 
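The fix guards the possibly-null taskMetrics field by lifting it into an Option before folding over the shuffle metrics. A minimal, self-contained sketch of that idiom (the types below are simplified stand-ins for illustration, not Spark's real TaskMetrics):

// Sketch of the null-guard used in the patch below. ShuffleReadMetrics, TaskMetrics,
// Summary and record are simplified stand-ins, not Spark's actual classes.
case class ShuffleReadMetrics(remoteBytesRead: Long)
case class TaskMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics])

class Summary { var shuffleRead: Long = 0L }

object NullGuardSketch {
  def record(summary: Summary, maybeNullMetrics: TaskMetrics): Unit = {
    // Option(x) is None when x is null, so a task that produced no metrics
    // (e.g. a failed task) is skipped instead of throwing a NullPointerException.
    Option(maybeNullMetrics).foreach { metrics =>
      metrics.shuffleReadMetrics.foreach { srm => summary.shuffleRead += srm.remoteBytesRead }
    }
  }

  def main(args: Array[String]): Unit = {
    val s = new Summary
    record(s, TaskMetrics(Some(ShuffleReadMetrics(42L))))
    record(s, null)          // the former NPE case, now a no-op
    println(s.shuffleRead)   // prints 42
  }
}

Because Option(...) maps null to None, the listener keeps accumulating without a try/catch; the actual change follows.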
--- .../scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 07a42f0503..2e51dd5a99 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -146,12 +146,9 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList // update duration y.taskTime += taskEnd.taskInfo.duration - taskEnd.taskMetrics.shuffleReadMetrics.foreach { shuffleRead => - y.shuffleRead += shuffleRead.remoteBytesRead - } - - taskEnd.taskMetrics.shuffleWriteMetrics.foreach { shuffleWrite => - y.shuffleWrite += shuffleWrite.shuffleBytesWritten + Option(taskEnd.taskMetrics).foreach { taskMetrics => + taskMetrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } + taskMetrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } } } case _ => {} -- cgit v1.2.3 From b662c88a24b853542846db538863e04f4862bc20 Mon Sep 17 00:00:00 2001 From: liguoqiang Date: Thu, 26 Dec 2013 15:49:33 +0800 Subject: fix this import order --- .../org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala | 2 +- .../scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 40307ab972..522e0a9ad7 100644 --- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -20,8 +20,8 @@ package org.apache.spark.scheduler.cluster import org.apache.spark._ import org.apache.hadoop.conf.Configuration import org.apache.spark.deploy.yarn.YarnAllocationHandler -import org.apache.spark.util.Utils import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.util.Utils /** * diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index b318270f75..a4638cc863 100644 --- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler.cluster import org.apache.spark._ import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils import org.apache.hadoop.conf.Configuration -import org.apache.spark.scheduler.TaskSchedulerImpl /** * -- cgit v1.2.3 From 4f2fb761b0d0c756b777a516b2ffb480264ab943 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 26 Dec 2013 15:26:06 -0800 Subject: Decrease margin of left side of log page --- .../main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 40d6bdb3fd..19aa800a95 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -140,12 +140,12 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I {linkToMaster}
[markup lost in rendering: this hunk's Scala XML literals were stripped down to bare +/- markers and interpolations. The change removes and re-adds the wrappers around {backButton}, {range}, and {nextButton} with adjusted inline styles, and shrinks the left-side margin of the pane that holds {logText}.]
    -- cgit v1.2.3 From 0cc1e0d43d90a4a6d48ced39a5ecbde163663efa Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 26 Dec 2013 23:21:08 -0800 Subject: SPARK-1007: spark-class2.cmd should change SCALA_VERSION to be 2.10 --- spark-class2.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark-class2.cmd b/spark-class2.cmd index a60c17d050..dc9dadf356 100644 --- a/spark-class2.cmd +++ b/spark-class2.cmd @@ -17,7 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SCALA_VERSION=2.9.3 +set SCALA_VERSION=2.10 rem Figure out where the Spark framework is installed set FWDIR=%~dp0 -- cgit v1.2.3 From 8c81068e16d4485e7f35dfaf99de6ee99fd76678 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Fri, 27 Dec 2013 11:36:54 -0800 Subject: Fixed >100char lines in DAGScheduler.scala --- .../org/apache/spark/scheduler/DAGScheduler.scala | 42 ++++++++++++++-------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 963d15b76d..2a131fde28 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -152,7 +152,8 @@ class DAGScheduler( val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + // Missing tasks from each stage + val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits val activeJobs = new HashSet[ActiveJob] @@ -239,7 +240,8 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) + val stage = + newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -248,7 +250,8 @@ class DAGScheduler( /** * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided - * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly. + * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage + * directly. */ private def newStage( rdd: RDD[_], @@ -358,7 +361,8 @@ class DAGScheduler( stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id val parents = getParentStages(s.rdd, jobId) - val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + val parentsWithoutThisJobId = parents.filter( + p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) } } @@ -366,8 +370,9 @@ class DAGScheduler( } /** - * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that - * were removed. 
The associated tasks for those stages need to be cancelled if we got here via job cancellation. + * Removes job and any stages that are not needed by any other job. Returns the set of ids for + * stages that were removed. The associated tasks for those stages need to be cancelled if we + * got here via job cancellation. */ private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { val registeredStages = jobIdToStageIds(jobId) @@ -378,7 +383,8 @@ class DAGScheduler( stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { case (stageId, jobSet) => if (!jobSet.contains(jobId)) { - logError("Job %d not registered for stage %d even though that stage was registered for the job" + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" .format(jobId, stageId)) } else { def removeStage(stageId: Int) { @@ -389,7 +395,8 @@ class DAGScheduler( running -= s } stageToInfos -= s - shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove) + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach( + shuffleToMapStage.remove) if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { logDebug("Removing pending status for stage %d".format(stageId)) } @@ -407,7 +414,8 @@ class DAGScheduler( stageIdToStage -= stageId stageIdToJobIds -= stageId - logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + logDebug("After removal of stage %d, remaining stages = %d" + .format(stageId, stageIdToStage.size)) } jobSet -= jobId @@ -459,7 +467,8 @@ class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) - eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) + eventProcessActor ! JobSubmitted( + jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) waiter } @@ -494,7 +503,8 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() - eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) + eventProcessActor ! JobSubmitted( + jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) listener.awaitResult() // Will throw an exception if the job fails } @@ -529,8 +539,8 @@ class DAGScheduler( case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => var finalStage: Stage = null try { - // New stage creation at times and if its not protected, the scheduler thread is killed. - // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted + // New stage creation may throw an exception if, for example, jobs are run on a HadoopRDD + // whose underlying HDFS files have been deleted. finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) } catch { case e: Exception => @@ -563,7 +573,8 @@ class DAGScheduler( case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. 
- val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val activeInGroup = activeJobs.filter( + groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach { handleJobCancellation } @@ -585,7 +596,8 @@ class DAGScheduler( stage <- stageIdToStage.get(task.stageId); stageInfo <- stageToInfos.get(stage) ) { - if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && + !stageInfo.emittedTaskSizeWarning) { stageInfo.emittedTaskSizeWarning = true logWarning(("Stage %d (%s) contains a task of very large " + "size (%d KB). The maximum recommended task size is %d KB.").format( -- cgit v1.2.3 From 0c71ffe924a158608b1760477b883e4818d53af4 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Fri, 27 Dec 2013 12:19:38 -0800 Subject: Style fixes as per Reynold's review --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2a131fde28..c48a3d64ef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -361,8 +361,8 @@ class DAGScheduler( stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id val parents = getParentStages(s.rdd, jobId) - val parentsWithoutThisJobId = parents.filter( - p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + val parentsWithoutThisJobId = parents.filter(p => + !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) } } @@ -395,8 +395,8 @@ class DAGScheduler( running -= s } stageToInfos -= s - shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach( - shuffleToMapStage.remove) + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleId => + shuffleToMapStage.remove(shuffleId)) if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { logDebug("Removing pending status for stage %d".format(stageId)) } @@ -573,8 +573,8 @@ class DAGScheduler( case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. 
- val activeInGroup = activeJobs.filter( - groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val activeInGroup = activeJobs.filter(activeJob => + groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach { handleJobCancellation } -- cgit v1.2.3 From 8419148e5fa54b5e3dd6b95fd5176b71506a951e Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Fri, 27 Dec 2013 15:14:38 -0800 Subject: Remove unused hasPendingTasks methods --- core/src/main/scala/org/apache/spark/scheduler/Pool.scala | 4 ---- core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala | 1 - .../main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 7 ------- .../src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 4 ---- 4 files changed, 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 596f9adde9..1791242215 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -117,8 +117,4 @@ private[spark] class Pool( parent.decreaseRunningTasks(taskNum) } } - - override def hasPendingTasks(): Boolean = { - schedulableQueue.exists(_.hasPendingTasks()) - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index 1c7ea2dccc..d573e125a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -42,5 +42,4 @@ private[spark] trait Schedulable { def executorLost(executorId: String, host: String): Unit def checkSpeculatableTasks(): Boolean def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] - def hasPendingTasks(): Boolean } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index dbac6b96ac..1b0f82fa24 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -365,13 +365,6 @@ private[spark] class TaskSchedulerImpl( } } - // Check for pending tasks in all our active jobs. - def hasPendingTasks: Boolean = { - synchronized { - rootPool.hasPendingTasks() - } - } - def executorLost(executorId: String, reason: ExecutorLossReason) { var failedExecutor: Option[String] = None diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c676e73e03..9b95e418d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -681,10 +681,6 @@ private[spark] class TaskSetManager( return foundTasks } - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksSuccessful < numTasks - } - private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { val defaultWait = System.getProperty("spark.locality.wait", "3000") level match { -- cgit v1.2.3 From e17d7518ab10c218c9576db5754c7fa4cb92688a Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Fri, 27 Dec 2013 15:51:27 -0800 Subject: Removed unused OtherFailure TaskEndReason. 
--- core/src/main/scala/org/apache/spark/TaskEndReason.scala | 2 -- core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala | 4 ---- 2 files changed, 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index c1e5e04b31..faf6dcd618 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -53,5 +53,3 @@ private[spark] case class ExceptionFailure( private[spark] case object TaskResultLost extends TaskEndReason private[spark] case object TaskKilled extends TaskEndReason - -private[spark] case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 60927831a1..be5c95e59e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -328,10 +328,6 @@ class JobLogger(val user: String, val logDirName: String) task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId stageLogInfo(task.stageId, taskStatus) - case OtherFailure(message) => - taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId + " INFO=" + message - stageLogInfo(task.stageId, taskStatus) case _ => } } -- cgit v1.2.3
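The commit above removes a TaskEndReason variant that nothing ever constructed, together with its dead branch in JobLogger. A rough sketch of why a sealed hierarchy makes that kind of dead case easy to spot (the variants and fields here are simplified for illustration and do not match Spark's full definitions):

// Illustrative only: a sealed trait of task end reasons. Because the hierarchy is
// sealed, the compiler checks matches for exhaustiveness, and a variant that no
// code path ever produces stands out as removable dead weight.
sealed trait EndReason
case object Succeeded extends EndReason
case object ResultLost extends EndReason
case object Killed extends EndReason
case class Failed(exceptionClass: String, description: String) extends EndReason

object EndReasonSketch {
  def logLine(reason: EndReason): String = reason match {
    case Succeeded         => "STATUS=SUCCESS"
    case ResultLost        => "STATUS=RESULT_LOST"
    case Killed            => "STATUS=KILLED"
    case Failed(cls, desc) => s"STATUS=FAILURE EXCEPTION=$cls ($desc)"
  }

  def main(args: Array[String]): Unit = {
    println(logLine(Killed))
    println(logLine(Failed("java.io.IOException", "disk full")))
  }
}

Once the last producer of a case is gone, the corresponding match arm (like the OtherFailure branch deleted above) carries no behaviour and can go with it.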
[markup lost in rendering: orphaned executor-summary table fragment. Header row: Executor ID | Address | Task Time | Total Tasks | Failed Tasks. Data row: {k} | {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} | {parent.formatDuration(v.taskTime)} | {v.failedTasks + v.succeededTasks} | {v.failedTasks}]
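The fragment above appears to come from a per-executor summary table, where k is an executor ID and v its accumulated stats. A hypothetical sketch of how such aggregates might be collected and printed (all names here are illustrative, not Spark's):

import scala.collection.mutable

// Illustrative per-executor aggregate matching the columns in the fragment above:
// executor ID, address, task time, total tasks, failed tasks.
case class ExecutorStats(var taskTime: Long = 0L,
                         var succeededTasks: Int = 0,
                         var failedTasks: Int = 0)

object ExecutorTableSketch {
  private val executorIdToAddress = mutable.Map("exec-1" -> "host-a:42001")
  private val stats = mutable.Map.empty[String, ExecutorStats]

  def onTaskEnd(execId: String, durationMs: Long, failed: Boolean): Unit = {
    val s = stats.getOrElseUpdate(execId, ExecutorStats())
    s.taskTime += durationMs
    if (failed) s.failedTasks += 1 else s.succeededTasks += 1
  }

  def main(args: Array[String]): Unit = {
    onTaskEnd("exec-1", durationMs = 1200, failed = false)
    onTaskEnd("exec-1", durationMs = 800, failed = true)
    for ((k, v) <- stats) {
      val addr = executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")
      println(s"$k  $addr  ${v.taskTime} ms  total=${v.succeededTasks + v.failedTasks}  failed=${v.failedTasks}")
    }
  }
}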