90 files changed, 2721 insertions, 1808 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c109ff930c..6f54fa7a5a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -43,11 +43,10 @@ import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend, -SimrSchedulerBackend, SparkDeploySchedulerBackend} -import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, -MesosSchedulerBackend} -import org.apache.spark.scheduler.local.LocalScheduler +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, + SparkDeploySchedulerBackend, SimrSchedulerBackend} +import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI import org.apache.spark.util._ @@ -560,9 +559,7 @@ class SparkContext( } addedFiles(key) = System.currentTimeMillis - // Fetch the file locally in case a job is executed locally. - // Jobs that run through LocalScheduler will already fetch the required dependencies, - // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. + // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) @@ -1070,18 +1067,30 @@ object SparkContext { // Regular expression for connection to Simr cluster val SIMR_REGEX = """simr://(.*)""".r + // When running locally, don't try to re-execute tasks on failure. 
+ val MAX_LOCAL_TASK_FAILURES = 1 + master match { case "local" => - new LocalScheduler(1, 0, sc) + val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) + val backend = new LocalBackend(scheduler, 1) + scheduler.initialize(backend) + scheduler case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0, sc) + val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) + val backend = new LocalBackend(scheduler, threads.toInt) + scheduler.initialize(backend) + scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt, sc) + val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) + val backend = new LocalBackend(scheduler, threads.toInt) + scheduler.initialize(backend) + scheduler case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) scheduler.initialize(backend) @@ -1096,7 +1105,7 @@ object SparkContext { memoryPerSlaveInt, sc.executorMemory)) } - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val masterUrls = localCluster.start() @@ -1111,7 +1120,7 @@ object SparkContext { val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(sc).asInstanceOf[ClusterScheduler] + cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { // TODO: Enumerate the exact reasons why it can fail // But irrespective of it, it means we cannot proceed ! 
@@ -1127,7 +1136,7 @@ object SparkContext { val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(sc).asInstanceOf[ClusterScheduler] + cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { case th: Throwable => { @@ -1137,7 +1146,7 @@ object SparkContext { val backend = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") - val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { case th: Throwable => { @@ -1150,7 +1159,7 @@ object SparkContext { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val coarseGrained = sc.conf.getOrElse("spark.mesos.coarse", "false").toBoolean val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { @@ -1162,7 +1171,7 @@ object SparkContext { scheduler case SIMR_REGEX(simrUrl) => - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) scheduler.initialize(backend) scheduler diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index c1e5e04b31..faf6dcd618 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -53,5 +53,3 @@ private[spark] case class ExceptionFailure( private[spark] case object TaskResultLost extends TaskEndReason private[spark] case object TaskKilled extends TaskEndReason - -private[spark] case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index ec47ba1b56..a801d85770 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -140,12 +140,12 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I <body> {linkToMaster} <div> - <div style="float:left;width:40%">{backButton}</div> + <div style="float:left; margin-right:10px">{backButton}</div> <div style="float:left;">{range}</div> - <div style="float:right;">{nextButton}</div> + <div style="float:right; margin-left:10px">{nextButton}</div> </div> <br /> - <div style="height:500px;overflow:auto;padding:5px;"> + <div style="height:500px; overflow:auto; padding:5px;"> <pre>{logText}</pre> </div> </body> diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 77aa24e6b6..e06e49d9d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -152,7 +152,8 @@ class DAGScheduler( val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new 
TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + // Missing tasks from each stage + val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits val activeJobs = new HashSet[ActiveJob] @@ -240,7 +241,8 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) + val stage = + newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -249,7 +251,8 @@ class DAGScheduler( /** * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided - * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly. + * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage + * directly. */ private def newStage( rdd: RDD[_], @@ -359,7 +362,8 @@ class DAGScheduler( stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id val parents = getParentStages(s.rdd, jobId) - val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + val parentsWithoutThisJobId = parents.filter(p => + !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) } } @@ -367,8 +371,9 @@ class DAGScheduler( } /** - * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that - * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation. + * Removes job and any stages that are not needed by any other job. Returns the set of ids for + * stages that were removed. The associated tasks for those stages need to be cancelled if we + * got here via job cancellation. 
*/ private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { val registeredStages = jobIdToStageIds(jobId) @@ -379,7 +384,8 @@ class DAGScheduler( stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { case (stageId, jobSet) => if (!jobSet.contains(jobId)) { - logError("Job %d not registered for stage %d even though that stage was registered for the job" + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" .format(jobId, stageId)) } else { def removeStage(stageId: Int) { @@ -390,7 +396,8 @@ class DAGScheduler( running -= s } stageToInfos -= s - shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove) + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleId => + shuffleToMapStage.remove(shuffleId)) if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { logDebug("Removing pending status for stage %d".format(stageId)) } @@ -408,7 +415,8 @@ class DAGScheduler( stageIdToStage -= stageId stageIdToJobIds -= stageId - logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + logDebug("After removal of stage %d, remaining stages = %d" + .format(stageId, stageIdToStage.size)) } jobSet -= jobId @@ -460,7 +468,8 @@ class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) - eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) + eventProcessActor ! JobSubmitted( + jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) waiter } @@ -495,7 +504,8 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() - eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) + eventProcessActor ! JobSubmitted( + jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) listener.awaitResult() // Will throw an exception if the job fails } @@ -530,8 +540,8 @@ class DAGScheduler( case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => var finalStage: Stage = null try { - // New stage creation at times and if its not protected, the scheduler thread is killed. - // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted + // New stage creation may throw an exception if, for example, jobs are run on a HadoopRDD + // whose underlying HDFS files have been deleted. finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) } catch { case e: Exception => @@ -564,7 +574,8 @@ class DAGScheduler( case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. 
- val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val activeInGroup = activeJobs.filter(activeJob => + groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach { handleJobCancellation } @@ -586,7 +597,8 @@ class DAGScheduler( stage <- stageIdToStage.get(task.stageId); stageInfo <- stageToInfos.get(stage) ) { - if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && + !stageInfo.emittedTaskSizeWarning) { stageInfo.emittedTaskSizeWarning = true logWarning(("Stage %d (%s) contains a task of very large " + "size (%d KB). The maximum recommended task size is %d KB.").format( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 5077b2b48b..2bc43a9186 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import org.apache.spark.executor.ExecutorExitCode diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 60927831a1..be5c95e59e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -328,10 +328,6 @@ class JobLogger(val user: String, val logDirName: String) task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId stageLogInfo(task.stageId, taskStatus) - case OtherFailure(message) => - taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId + " INFO=" + message - stageLogInfo(task.stageId, taskStatus) case _ => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 596f9adde9..1791242215 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -117,8 +117,4 @@ private[spark] class Pool( parent.decreaseRunningTasks(taskNum) } } - - override def hasPendingTasks(): Boolean = { - schedulableQueue.exists(_.hasPendingTasks()) - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index 1c7ea2dccc..d573e125a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -42,5 +42,4 @@ private[spark] trait Schedulable { def executorLost(executorId: String, host: String): Unit def checkSpeculatableTasks(): Boolean def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] - def hasPendingTasks(): Boolean } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 65d3fc8187..02bdbba825 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -15,12 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import org.apache.spark.SparkContext /** - * A backend interface for cluster scheduling systems that allows plugging in different ones under + * A backend interface for scheduling systems that allows plugging in different ones under * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as * machines become available and can launch tasks on them. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3841b5616d..ee63b3c4a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -63,7 +63,7 @@ trait SparkListener { * Called when a task begins remotely fetching its result (will not be called for tasks that do * not need to fetch the result remotely). */ - def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } /** * Called when a task ends @@ -131,8 +131,8 @@ object StatsReportListener extends Logging { def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { val stats = d.statCounter - logInfo(heading + stats) val quantiles = d.getQuantiles(probabilities).map{formatNumber} + logInfo(heading + stats) logInfo(percentilesHeader) logInfo("\t" + quantiles.mkString("\t")) } @@ -173,8 +173,6 @@ object StatsReportListener extends Logging { showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) } - - val seconds = 1000L val minutes = seconds * 60 val hours = minutes * 60 @@ -198,7 +196,6 @@ object StatsReportListener extends Logging { } - case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) object RuntimePercentage { def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index d5824e7954..85687ea330 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging { return true
}
}
-
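Note on the local-mode wiring in the SparkContext hunk above: LocalScheduler is gone, and every master string now produces the same TaskSchedulerImpl, differing only in the SchedulerBackend it is initialized with. A minimal sketch of the new "local[N]" branch, using only names that appear in that hunk (`sc` is the enclosing SparkContext and `threads` is the thread-count string captured by LOCAL_N_REGEX):

    // Local mode after the refactoring: shared scheduler, pluggable backend.
    // MAX_LOCAL_TASK_FAILURES = 1, so tasks that fail locally are not retried.
    val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
    val backend = new LocalBackend(scheduler, threads.toInt)
    scheduler.initialize(backend)
    scheduler

The cluster masters follow the same pattern, substituting SparkDeploySchedulerBackend, SimrSchedulerBackend, or one of the Mesos backends for LocalBackend.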
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 319c91b933..29b0247f8a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -15,21 +15,20 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit} import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.Utils /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. */ -private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) +private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends Logging { private val THREADS = sparkEnv.conf.getOrElse("spark.resultGetter.threads", "4").toInt @@ -43,7 +42,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche } def enqueueSuccessfulTask( - taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) { + taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { getTaskResultExecutor.execute(new Runnable { override def run() { try { @@ -79,7 +78,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche }) } - def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState, + def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { var reason: Option[TaskEndReason] = None getTaskResultExecutor.execute(new Runnable { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 10e0478108..17b6d97e90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -20,11 +20,12 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** - * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler. - * Each TaskScheduler schedulers task for a single SparkContext. - * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, - * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler. + * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler. + * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks + * for a single SparkContext. These schedulers get sets of tasks submitted to them from the + * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running + * them, retrying if there are failures, and mitigating stragglers. They return events to the + * DAGScheduler. 
*/ private[spark] trait TaskScheduler { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 2707740d44..56a038dc69 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong @@ -28,37 +28,40 @@ import scala.concurrent.duration._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** - * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call - * initialize() and start(), then submit task sets through the runTasks method. - * - * This class can work with multiple types of clusters by acting through a SchedulerBackend. + * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. + * It can also work with a local setup by using a LocalBackend and setting isLocal to true. * It handles common logic, like determining a scheduling order across jobs, waking up to launch * speculative tasks, etc. * + * Clients should first call initialize() and start(), then submit task sets through the + * runTasks method. + * * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some * SchedulerBackends sycnchronize on themselves when they want to send events here, and then * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ -private[spark] class ClusterScheduler(val sc: SparkContext) - extends TaskScheduler - with Logging +private[spark] class TaskSchedulerImpl( + val sc: SparkContext, + val maxTaskFailures: Int = System.getProperty("spark.task.maxFailures", "4").toInt, + isLocal: Boolean = false) + extends TaskScheduler with Logging { val conf = sc.conf + // How often to check for speculative tasks val SPECULATION_INTERVAL = conf.getOrElse("spark.speculation.interval", "100").toLong // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT = conf.getOrElse("spark.starvation.timeout", "15000").toLong - // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized + // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. 
- val activeTaskSets = new HashMap[String, ClusterTaskSetManager] + val activeTaskSets = new HashMap[String, TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] @@ -120,7 +123,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def start() { backend.start() - if (conf.getOrElse("spark.speculation", "false").toBoolean) { + if (!isLocal && conf.getOrElse("spark.speculation", "false").toBoolean) { logInfo("Starting speculative execution thread") import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, @@ -134,12 +137,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new ClusterTaskSetManager(this, taskSet) + val manager = new TaskSetManager(this, taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) taskSetTaskIds(taskSet.id) = new HashSet[Long]() - if (!hasReceivedTask) { + if (!isLocal && !hasReceivedTask) { starvationTimer.scheduleAtFixedRate(new TimerTask() { override def run() { if (!hasLaunchedTask) { @@ -293,19 +296,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) { + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { taskSetManager.handleTaskGettingResult(tid) } def handleSuccessfulTask( - taskSetManager: ClusterTaskSetManager, + taskSetManager: TaskSetManager, tid: Long, taskResult: DirectTaskResult[_]) = synchronized { taskSetManager.handleSuccessfulTask(tid, taskResult) } def handleFailedTask( - taskSetManager: ClusterTaskSetManager, + taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, reason: Option[TaskEndReason]) = synchronized { @@ -353,7 +356,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def defaultParallelism() = backend.defaultParallelism() - // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { var shouldRevive = false @@ -365,13 +367,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - // Check for pending tasks in all our active jobs. - def hasPendingTasks: Boolean = { - synchronized { - rootPool.hasPendingTasks() - } - } - def executorLost(executorId: String, reason: ExecutorLossReason) { var failedExecutor: Option[String] = None @@ -430,7 +425,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } -object ClusterScheduler { +private[spark] object TaskSchedulerImpl { /** * Used to balance containers across hosts. 
* diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 90f6bcefac..9b95e418d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -17,32 +17,702 @@ package org.apache.spark.scheduler -import java.nio.ByteBuffer +import java.io.NotSerializableException +import java.util.Arrays +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.math.max +import scala.math.min + +import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, + Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState +import org.apache.spark.util.{Clock, SystemClock} + /** - * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of - * each task and is responsible for retries on failure and locality. The main interfaces to it - * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and - * statusUpdate, which tells it that one of its tasks changed state (e.g. finished). + * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of + * each task, retries tasks if they fail (up to a limited number of times), and + * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces + * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, + * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished). + * + * THREADING: This class is designed to only be called from code with a lock on the + * TaskScheduler (e.g. its event handlers). It should not be called from other threads. * - * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler - * (e.g. its event handlers). It should not be called from other threads. + * @param sched the ClusterScheduler associated with the TaskSetManager + * @param taskSet the TaskSet to manage scheduling for + * @param maxTaskFailures if any particular task fails more than this number of times, the entire + * task set will be aborted */ -private[spark] trait TaskSetManager extends Schedulable { - def schedulableQueue = null - - def schedulingMode = SchedulingMode.NONE - - def taskSet: TaskSet +private[spark] class TaskSetManager( + sched: TaskSchedulerImpl, + val taskSet: TaskSet, + val maxTaskFailures: Int, + clock: Clock = SystemClock) + extends Schedulable with Logging +{ + // CPUs to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt + + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble + + // Serializer for closures and tasks. 
+ val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() + + val tasks = taskSet.tasks + val numTasks = tasks.length + val copiesRunning = new Array[Int](numTasks) + val successful = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + var tasksSuccessful = 0 + + var weight = 1 + var minShare = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent: Pool = null + + var runningTasks = 0 + private val runningTasksSet = new HashSet[Long] + + // Set of pending tasks for each executor. These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] + + // Set of pending tasks for each host. Similar to pendingTasksForExecutor, + // but at host level. + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // Set of pending tasks for each rack -- similar to the above. + private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] + + // Set containing pending tasks with no locality preferences. + val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + + // Set containing all pending tasks (also used as a stack, as above). + val allPendingTasks = new ArrayBuffer[Int] + + // Tasks that can be speculated. Since these will be a small fraction of total + // tasks, we'll just hold them in a HashSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[Long, TaskInfo] + + // Did the TaskSet fail? + var failed = false + var causeOfFailure = "" + + // How frequently to reprint duplicate exceptions in full, in milliseconds + val EXCEPTION_PRINT_INTERVAL = + System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong + + // Map of recent exceptions (identified by string representation and top stack frame) to + // duplicate count (how many times the same exception has appeared) and time the full exception + // was printed. This should ideally be an LRU map that can drop old exceptions automatically. + val recentExceptions = HashMap[String, (Int, Long)]() + + // Figure out the current map output tracker epoch and set it on all tasks + val epoch = sched.mapOutputTracker.getEpoch + logDebug("Epoch for " + taskSet + ": " + epoch) + for (t <- tasks) { + t.epoch = epoch + } + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. + for (i <- (0 until numTasks).reverse) { + addPendingTask(i) + } + + // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling + val myLocalityLevels = computeValidLocalityLevels() + val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + + // Delay scheduling variables: we keep track of our current locality level and the time we + // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. 
+ // We then move down if we manage to launch a "more local" task. + var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + var lastLaunchTime = clock.getTime() // Time we last launched a task at this level + + override def schedulableQueue = null + + override def schedulingMode = SchedulingMode.NONE + + /** + * Add a task to all the pending-task lists that it should be on. If readding is set, we are + * re-adding the task so only include it in each list if it's not already there. + */ + private def addPendingTask(index: Int, readding: Boolean = false) { + // Utility method that adds `index` to a list only if readding=false or it's not already there + def addTo(list: ArrayBuffer[Int]) { + if (!readding || !list.contains(index)) { + list += index + } + } + + var hadAliveLocations = false + for (loc <- tasks(index).preferredLocations) { + for (execId <- loc.executorId) { + if (sched.isExecutorAlive(execId)) { + addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) + hadAliveLocations = true + } + } + if (sched.hasExecutorsAliveOnHost(loc.host)) { + addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) + for (rack <- sched.getRackForHost(loc.host)) { + addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) + } + hadAliveLocations = true + } + } + + if (!hadAliveLocations) { + // Even though the task might've had preferred locations, all of those hosts or executors + // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. + addTo(pendingTasksWithNoPrefs) + } + + if (!readding) { + allPendingTasks += index // No point scanning this whole list to find the old task there + } + } + + /** + * Return the pending tasks list for a given executor ID, or an empty list if + * there is no map entry for that host + */ + private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = { + pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer()) + } + + /** + * Return the pending tasks list for a given host, or an empty list if + * there is no map entry for that host + */ + private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + pendingTasksForHost.getOrElse(host, ArrayBuffer()) + } + + /** + * Return the pending rack-local task list for a given rack, or an empty list if + * there is no map entry for that rack + */ + private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = { + pendingTasksForRack.getOrElse(rack, ArrayBuffer()) + } + + /** + * Dequeue a pending task from the given list and return its index. + * Return None if the list is empty. + * This method also cleans up any tasks in the list that have already + * been launched, since we want that to happen lazily. + */ + private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + while (!list.isEmpty) { + val index = list.last + list.trimEnd(1) + if (copiesRunning(index) == 0 && !successful(index)) { + return Some(index) + } + } + return None + } + + /** Check whether a task is currently running an attempt on a given host */ + private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { + !taskAttempts(taskIndex).exists(_.host == host) + } + + /** + * Return a speculative task for a given executor if any are available. The task should not have + * an attempt running on this host, in case the host is slow. In addition, the task should meet + * the given locality constraint. 
+ */ + private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + : Option[(Int, TaskLocality.Value)] = + { + speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set + + if (!speculatableTasks.isEmpty) { + // Check for process-local or preference-less tasks; note that tasks can be process-local + // on multiple nodes when we replicate cached blocks, as in Spark Streaming + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val prefs = tasks(index).preferredLocations + val executors = prefs.flatMap(_.executorId) + if (prefs.size == 0 || executors.contains(execId)) { + speculatableTasks -= index + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + } + + // Check for node-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val locations = tasks(index).preferredLocations.map(_.host) + if (locations.contains(host)) { + speculatableTasks -= index + return Some((index, TaskLocality.NODE_LOCAL)) + } + } + } + // Check for rack-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + for (rack <- sched.getRackForHost(host)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) + if (racks.contains(rack)) { + speculatableTasks -= index + return Some((index, TaskLocality.RACK_LOCAL)) + } + } + } + } + + // Check for non-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + speculatableTasks -= index + return Some((index, TaskLocality.ANY)) + } + } + } + + return None + } + + /** + * Dequeue a pending task for a given node and return its index and locality level. + * Only search for tasks matching the given locality constraint. + */ + private def findTask(execId: String, host: String, locality: TaskLocality.Value) + : Option[(Int, TaskLocality.Value)] = + { + for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + + if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + for (index <- findTaskFromList(getPendingTasksForHost(host))) { + return Some((index, TaskLocality.NODE_LOCAL)) + } + } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + for { + rack <- sched.getRackForHost(host) + index <- findTaskFromList(getPendingTasksForRack(rack)) + } { + return Some((index, TaskLocality.RACK_LOCAL)) + } + } + + // Look for no-pref tasks after rack-local tasks since they can run anywhere. 
+ for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + for (index <- findTaskFromList(allPendingTasks)) { + return Some((index, TaskLocality.ANY)) + } + } + + // Finally, if all else has failed, find a speculative task + return findSpeculativeTask(execId, host, locality) + } + + /** + * Respond to an offer of a single executor from the scheduler by finding a task + */ def resourceOffer( execId: String, host: String, availableCpus: Int, maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] + : Option[TaskDescription] = + { + if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) { + val curTime = clock.getTime() + + var allowedLocality = getAllowedLocalityLevel(curTime) + if (allowedLocality > maxLocality) { + allowedLocality = maxLocality // We're not allowed to search for farther-away tasks + } + + findTask(execId, host, allowedLocality) match { + case Some((index, taskLocality)) => { + // Found a task; do some bookkeeping and return a task description + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred launch + logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( + taskSet.id, index, taskId, execId, host, taskLocality)) + // Do various bookkeeping + copiesRunning(index) += 1 + val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + // Update our locality level for delay scheduling + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + // Serialize and return the task + val startTime = clock.getTime() + // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here + // we assume the task can be serialized without exceptions. + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val timeTaken = clock.getTime() - startTime + addRunningTask(taskId) + logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) + val taskName = "task %s:%d".format(taskSet.id, index) + if (taskAttempts(index).size == 1) + taskStarted(task,info) + return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) + } + case _ => + } + } + return None + } + + /** + * Get the level we can launch tasks according to delay scheduling, based on current wait time. + */ + private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { + while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && + currentLocalityIndex < myLocalityLevels.length - 1) + { + // Jump to the next locality level, and remove our waiting time for the current one since + // we don't want to count it again on the next one + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + } + myLocalityLevels(currentLocalityIndex) + } + + /** + * Find the index in myLocalityLevels for a given locality. This is also designed to work with + * localities that are not in myLocalityLevels (in case we somehow get those) by returning the + * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. 
+ */ + def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { + var index = 0 + while (locality > myLocalityLevels(index)) { + index += 1 + } + index + } + + private def taskStarted(task: Task[_], info: TaskInfo) { + sched.dagScheduler.taskStarted(task, info) + } + + def handleTaskGettingResult(tid: Long) = { + val info = taskInfos(tid) + info.markGettingResult() + sched.dagScheduler.taskGettingResult(tasks(info.index), info) + } + + /** + * Marks the task as successful and notifies the DAGScheduler that a task has ended. + */ + def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { + val info = taskInfos(tid) + val index = info.index + info.markSuccessful() + removeRunningTask(tid) + if (!successful(index)) { + logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( + tid, info.duration, info.host, tasksSuccessful, numTasks)) + sched.dagScheduler.taskEnded( + tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + + // Mark successful and stop if all the tasks have succeeded. + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignorning task-finished event for TID " + tid + " because task " + + index + " has already completed successfully") + } + } + + /** + * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the + * DAG Scheduler. + */ + def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { + val info = taskInfos(tid) + if (info.failed) { + return + } + removeRunningTask(tid) + val index = info.index + info.markFailed() + var failureReason = "unknown" + if (!successful(index)) { + logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. In that case, this + // task will never succeed on any node, so tell the scheduler about it. + reason.foreach { + case fetchFailed: FetchFailed => + logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) + successful(index) = true + tasksSuccessful += 1 + sched.taskSetFinished(this) + removeAllRunningTasks() + return + + case TaskKilled => + logWarning("Task %d was killed.".format(tid)) + sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) + return + + case ef: ExceptionFailure => + sched.dagScheduler.taskEnded( + tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + if (ef.className == classOf[NotSerializableException].getName()) { + // If the task result wasn't rerializable, there's no point in trying to re-execute it. 
+ logError("Task %s:%s had a not serializable result: %s; not retrying".format( + taskSet.id, index, ef.description)) + abort("Task %s:%s had a not serializable result: %s".format( + taskSet.id, index, ef.description)) + return + } + val key = ef.description + failureReason = "Exception failure: %s".format(ef.description) + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logWarning("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } + + case TaskResultLost => + failureReason = "Lost result for TID %s on host %s".format(tid, info.host) + logWarning(failureReason) + sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) + + case _ => {} + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + if (state != TaskState.KILLED) { + numFailures(index) += 1 + if (numFailures(index) >= maxTaskFailures) { + logError("Task %s:%d failed %d times; aborting job".format( + taskSet.id, index, maxTaskFailures)) + abort("Task %s:%d failed %d times (most recent failure: %s)".format( + taskSet.id, index, maxTaskFailures, failureReason)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def error(message: String) { + // Save the error message + abort("Error: " + message) + } + + def abort(message: String) { + failed = true + causeOfFailure = message + // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.dagScheduler.taskSetFailed(taskSet, message) + removeAllRunningTasks() + sched.taskSetFinished(this) + } + + /** If the given task ID is not in the set of running tasks, adds it. + * + * Used to keep track of the number of running tasks, for enforcing scheduling policies. + */ + def addRunningTask(tid: Long) { + if (runningTasksSet.add(tid) && parent != null) { + parent.increaseRunningTasks(1) + } + runningTasks = runningTasksSet.size + } + + /** If the given task ID is in the set of running tasks, removes it. 
*/ + def removeRunningTask(tid: Long) { + if (runningTasksSet.remove(tid) && parent != null) { + parent.decreaseRunningTasks(1) + } + runningTasks = runningTasksSet.size + } + + private[scheduler] def removeAllRunningTasks() { + val numRunningTasks = runningTasksSet.size + runningTasksSet.clear() + if (parent != null) { + parent.decreaseRunningTasks(numRunningTasks) + } + runningTasks = 0 + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def addSchedulable(schedulable: Schedulable) {} + + override def removeSchedulable(schedulable: Schedulable) {} + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ + override def executorLost(execId: String, host: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + + // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a + // task that used to have locations on only this host might now go to the no-prefs list. Note + // that it's okay if we add a task to the same queue twice (if it had multiple preferred + // locations), because findTaskFromList will skip already-running tasks. + for (index <- getPendingTasksForExecutor(execId)) { + addPendingTask(index, readding=true) + } + for (index <- getPendingTasksForHost(host)) { + addPendingTask(index, readding=true) + } + + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.executorId == execId) { + val index = taskInfos(tid).index + if (successful(index)) { + successful(index) = false + copiesRunning(index) -= 1 + tasksSuccessful -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) + } + } + } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { + handleFailedTask(tid, TaskState.KILLED, None) + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the TaskScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + override def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. 
+ if (numTasks == 1 || tasksSuccessful == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { + val time = clock.getTime() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. + logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo( + "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.host, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks + } + + private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { + val defaultWait = System.getProperty("spark.locality.wait", "3000") + level match { + case TaskLocality.PROCESS_LOCAL => + System.getProperty("spark.locality.wait.process", defaultWait).toLong + case TaskLocality.NODE_LOCAL => + System.getProperty("spark.locality.wait.node", defaultWait).toLong + case TaskLocality.RACK_LOCAL => + System.getProperty("spark.locality.wait.rack", defaultWait).toLong + case TaskLocality.ANY => + 0L + } + } - def error(message: String) + /** + * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been + * added to queues using addPendingTask. + */ + private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { + import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} + val levels = new ArrayBuffer[TaskLocality.TaskLocality] + if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { + levels += PROCESS_LOCAL + } + if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { + levels += NODE_LOCAL + } + if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { + levels += RACK_LOCAL + } + levels += ANY + logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) + levels.toArray + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index 938f62883a..ba6bab3f91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler /** * Represents free resources available on an executor. 
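The checkSpeculatableTasks logic added in the TaskSetManager hunk above derives its cutoff from the median duration of already-successful tasks. A hedged illustration of that computation, with made-up sample durations (the patch itself reads the durations from taskInfos and indexes by tasksSuccessful):

    // Illustrative only: how the speculation threshold is computed.
    val durations = Array(90L, 100L, 110L, 400L).sorted   // hypothetical successful-task times, in ms
    val median = durations(math.min((0.5 * durations.length).round.toInt, durations.length - 1))
    val threshold = math.max(1.5 * median, 100)            // spark.speculation.multiplier defaults to 1.5
    // A still-running copy becomes speculatable once its elapsed time exceeds `threshold`,
    // and speculation only kicks in after SPECULATION_QUANTILE (default 0.75) of the tasks have finished.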
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala deleted file mode 100644 index a46b16b92f..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ /dev/null @@ -1,714 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import java.io.NotSerializableException -import java.util.Arrays - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, - Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler._ -import org.apache.spark.util.{SystemClock, Clock} - - -/** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of - * the status of each task, retries tasks if they fail (up to a limited number of times), and - * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces - * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, - * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished). - * - * THREADING: This class is designed to only be called from code with a lock on the - * ClusterScheduler (e.g. its event handlers). It should not be called from other threads. - */ -private[spark] class ClusterTaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet, - clock: Clock = SystemClock) - extends TaskSetManager - with Logging -{ - val conf = sched.sc.conf - // CPUs to request per task - val CPUS_PER_TASK = conf.getOrElse("spark.task.cpus", "1").toInt - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = conf.getOrElse("spark.task.maxFailures", "4").toInt - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = conf.getOrElse("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = conf.getOrElse("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. 
- val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val successful = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksSuccessful = 0 - - var weight = 1 - var minShare = 0 - var priority = taskSet.priority - var stageId = taskSet.stageId - var name = "TaskSet_"+taskSet.stageId.toString - var parent: Pool = null - - var runningTasks = 0 - private val runningTasksSet = new HashSet[Long] - - // Set of pending tasks for each executor. These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. - private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each host. Similar to pendingTasksForExecutor, - // but at host level. - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each rack -- similar to the above. - private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] - - // Set containing pending tasks with no locality preferences. - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the TaskSet fail? - var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - conf.getOrElse("spark.logging.exceptionPrintInterval", "10000").toLong - - // Map of recent exceptions (identified by string representation and top stack frame) to - // duplicate count (how many times the same exception has appeared) and time the full exception - // was printed. This should ideally be an LRU map that can drop old exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker epoch and set it on all tasks - val epoch = sched.mapOutputTracker.getEpoch - logDebug("Epoch for " + taskSet + ": " + epoch) - for (t <- tasks) { - t.epoch = epoch - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. - for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling - val myLocalityLevels = computeValidLocalityLevels() - val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level - - // Delay scheduling variables: we keep track of our current locality level and the time we - // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. 
- // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level - - /** - * Add a task to all the pending-task lists that it should be on. If readding is set, we are - * re-adding the task so only include it in each list if it's not already there. - */ - private def addPendingTask(index: Int, readding: Boolean = false) { - // Utility method that adds `index` to a list only if readding=false or it's not already there - def addTo(list: ArrayBuffer[Int]) { - if (!readding || !list.contains(index)) { - list += index - } - } - - var hadAliveLocations = false - for (loc <- tasks(index).preferredLocations) { - for (execId <- loc.executorId) { - if (sched.isExecutorAlive(execId)) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) - hadAliveLocations = true - } - } - if (sched.hasExecutorsAliveOnHost(loc.host)) { - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) - for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - } - hadAliveLocations = true - } - } - - if (!hadAliveLocations) { - // Even though the task might've had preferred locations, all of those hosts or executors - // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. - addTo(pendingTasksWithNoPrefs) - } - - if (!readding) { - allPendingTasks += index // No point scanning this whole list to find the old task there - } - } - - /** - * Return the pending tasks list for a given executor ID, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = { - pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer()) - } - - /** - * Return the pending tasks list for a given host, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - /** - * Return the pending rack-local task list for a given rack, or an empty list if - * there is no map entry for that rack - */ - private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = { - pendingTasksForRack.getOrElse(rack, ArrayBuffer()) - } - - /** - * Dequeue a pending task from the given list and return its index. - * Return None if the list is empty. - * This method also cleans up any tasks in the list that have already - * been launched, since we want that to happen lazily. - */ - private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !successful(index)) { - return Some(index) - } - } - return None - } - - /** Check whether a task is currently running an attempt on a given host */ - private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { - !taskAttempts(taskIndex).exists(_.host == host) - } - - /** - * Return a speculative task for a given executor if any are available. The task should not have - * an attempt running on this host, in case the host is slow. In addition, the task should meet - * the given locality constraint. 
- */ - private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set - - if (!speculatableTasks.isEmpty) { - // Check for process-local or preference-less tasks; note that tasks can be process-local - // on multiple nodes when we replicate cached blocks, as in Spark Streaming - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val prefs = tasks(index).preferredLocations - val executors = prefs.flatMap(_.executorId) - if (prefs.size == 0 || executors.contains(execId)) { - speculatableTasks -= index - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - } - - // Check for node-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val locations = tasks(index).preferredLocations.map(_.host) - if (locations.contains(host)) { - speculatableTasks -= index - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - } - - // Check for rack-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for (rack <- sched.getRackForHost(host)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) - if (racks.contains(rack)) { - speculatableTasks -= index - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - } - } - - // Check for non-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - speculatableTasks -= index - return Some((index, TaskLocality.ANY)) - } - } - } - - return None - } - - /** - * Dequeue a pending task for a given node and return its index and locality level. - * Only search for tasks matching the given locality constraint. - */ - private def findTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- findTaskFromList(getPendingTasksForHost(host))) { - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for { - rack <- sched.getRackForHost(host) - index <- findTaskFromList(getPendingTasksForRack(rack)) - } { - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - - // Look for no-pref tasks after rack-local tasks since they can run anywhere. 
- for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- findTaskFromList(allPendingTasks)) { - return Some((index, TaskLocality.ANY)) - } - } - - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(execId, host, locality) - } - - /** - * Respond to an offer of a single executor from the scheduler by finding a task - */ - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) { - val curTime = clock.getTime() - - var allowedLocality = getAllowedLocalityLevel(curTime) - if (allowedLocality > maxLocality) { - allowedLocality = maxLocality // We're not allowed to search for farther-away tasks - } - - findTask(execId, host, allowedLocality) match { - case Some((index, taskLocality)) => { - // Found a task; do some bookkeeping and return a task description - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, taskLocality)) - // Do various bookkeeping - copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime - // Serialize and return the task - val startTime = clock.getTime() - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = clock.getTime() - startTime - addRunningTask(taskId) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - info.serializedSize = serializedTask.limit - if (taskAttempts(index).size == 1) - taskStarted(task,info) - return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) - } - case _ => - } - } - return None - } - - /** - * Get the level we can launch tasks according to delay scheduling, based on current wait time. - */ - private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 - } - myLocalityLevels(currentLocalityIndex) - } - - /** - * Find the index in myLocalityLevels for a given locality. This is also designed to work with - * localities that are not in myLocalityLevels (in case we somehow get those) by returning the - * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. 
- */ - def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { - var index = 0 - while (locality > myLocalityLevels(index)) { - index += 1 - } - index - } - - private def taskStarted(task: Task[_], info: TaskInfo) { - sched.dagScheduler.taskStarted(task, info) - } - - def handleTaskGettingResult(tid: Long) = { - val info = taskInfos(tid) - info.markGettingResult() - sched.dagScheduler.taskGettingResult(tasks(info.index), info) - } - - /** - * Marks the task as successful and notifies the DAGScheduler that a task has ended. - */ - def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { - val info = taskInfos(tid) - val index = info.index - info.markSuccessful() - removeRunningTask(tid) - if (!successful(index)) { - logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( - tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.dagScheduler.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - - // Mark successful and stop if all the tasks have succeeded. - tasksSuccessful += 1 - successful(index) = true - if (tasksSuccessful == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignorning task-finished event for TID " + tid + " because task " + - index + " has already completed successfully") - } - } - - /** - * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the - * DAG Scheduler. - */ - def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { - val info = taskInfos(tid) - if (info.failed) { - return - } - removeRunningTask(tid) - val index = info.index - info.markFailed() - if (!successful(index)) { - logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. - reason.foreach { - case fetchFailed: FetchFailed => - logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) - successful(index) = true - tasksSuccessful += 1 - sched.taskSetFinished(this) - removeAllRunningTasks() - return - - case TaskKilled => - logWarning("Task %d was killed.".format(tid)) - sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) - return - - case ef: ExceptionFailure => - sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - if (ef.className == classOf[NotSerializableException].getName()) { - // If the task result wasn't serializable, there's no point in trying to re-execute it. 
- logError("Task %s:%s had a not serializable result: %s; not retrying".format( - taskSet.id, index, ef.description)) - abort("Task %s:%s had a not serializable result: %s".format( - taskSet.id, index, ef.description)) - return - } - val key = ef.description - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logWarning("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) - } - - case TaskResultLost => - logWarning("Lost result for TID %s on host %s".format(tid, info.host)) - sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) - - case _ => {} - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - if (state != TaskState.KILLED) { - numFailures(index) += 1 - if (numFailures(index) >= MAX_TASK_FAILURES) { - logError("Task %s:%d failed %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - override def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def abort(message: String) { - failed = true - causeOfFailure = message - // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) - removeAllRunningTasks() - sched.taskSetFinished(this) - } - - /** If the given task ID is not in the set of running tasks, adds it. - * - * Used to keep track of the number of running tasks, for enforcing scheduling policies. - */ - def addRunningTask(tid: Long) { - if (runningTasksSet.add(tid) && parent != null) { - parent.increaseRunningTasks(1) - } - runningTasks = runningTasksSet.size - } - - /** If the given task ID is in the set of running tasks, removes it. 
*/ - def removeRunningTask(tid: Long) { - if (runningTasksSet.remove(tid) && parent != null) { - parent.decreaseRunningTasks(1) - } - runningTasks = runningTasksSet.size - } - - private[cluster] def removeAllRunningTasks() { - val numRunningTasks = runningTasksSet.size - runningTasksSet.clear() - if (parent != null) { - parent.decreaseRunningTasks(numRunningTasks) - } - runningTasks = 0 - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def addSchedulable(schedulable: Schedulable) {} - - override def removeSchedulable(schedulable: Schedulable) {} - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a - // task that used to have locations on only this host might now go to the no-prefs list. Note - // that it's okay if we add a task to the same queue twice (if it had multiple preferred - // locations), because findTaskFromList will skip already-running tasks. - for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) - } - for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) - } - - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (successful(index)) { - successful(index) = false - copiesRunning(index) -= 1 - tasksSuccessful -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.KILLED, None) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. 
- if (numTasks == 1 || tasksSuccessful == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { - val time = clock.getTime() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. - logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } - - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksSuccessful < numTasks - } - - private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.getOrElse("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - conf.getOrElse("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - conf.getOrElse("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - conf.getOrElse("spark.locality.wait.rack", defaultWait).toLong - case TaskLocality.ANY => - 0L - } - } - - /** - * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been - * added to queues using addPendingTask. 
- */ - private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { - import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} - val levels = new ArrayBuffer[TaskLocality.TaskLocality] - if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { - levels += PROCESS_LOCAL - } - if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { - levels += NODE_LOCAL - } - if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { - levels += RACK_LOCAL - } - levels += ANY - logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) - levels.toArray - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 156b01b149..b4a3ecca39 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -28,8 +28,10 @@ import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.{Logging, SparkException, TaskState} -import org.apache.spark.scheduler.TaskDescription +import org.apache.spark.scheduler.{TaskSchedulerImpl, SchedulerBackend, SlaveLost, TaskDescription, + WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -42,7 +44,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index d74f000ebb..f41fbbd1f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,10 +19,12 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem} + import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class SimrSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index de69e3260d..224077566d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,14 +17,16 @@ package org.apache.spark.scheduler.cluster +import scala.collection.mutable.HashMap + import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.client.{Client, ClientListener} import 
org.apache.spark.deploy.{Command, ApplicationDescription} -import scala.collection.mutable.HashMap +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String], appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 1695374152..9e2cd3f699 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,7 +30,8 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{SparkException, Logging, SparkContext, TaskState} -import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -43,7 +44,7 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedu * remove this. */ private[spark] class CoarseMesosSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, master: String, appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8dfd4d5fb3..be96382983 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -30,9 +30,8 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{Logging, SparkException, SparkContext, TaskState} -import org.apache.spark.scheduler.TaskDescription -import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason} -import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, + TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.util.Utils /** @@ -41,7 +40,7 @@ import org.apache.spark.util.Utils * from multiple apps can run on different cores) and in time (a core can switch ownership). */ private[spark] class MesosSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, master: String, appName: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala new file mode 100644 index 0000000000..4edc6a0d3f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.local + +import java.nio.ByteBuffer + +import akka.actor.{Actor, ActorRef, Props} + +import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} + +private case class ReviveOffers() + +private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) + +private case class KillTask(taskId: Long) + +/** + * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on + * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * and the TaskSchedulerImpl. + */ +private[spark] class LocalActor( + scheduler: TaskSchedulerImpl, + executorBackend: LocalBackend, + private val totalCores: Int) extends Actor with Logging { + + private var freeCores = totalCores + + private val localExecutorId = "localhost" + private val localExecutorHostname = "localhost" + + val executor = new Executor(localExecutorId, localExecutorHostname, Seq.empty, isLocal = true) + + def receive = { + case ReviveOffers => + reviveOffers() + + case StatusUpdate(taskId, state, serializedData) => + scheduler.statusUpdate(taskId, state, serializedData) + if (TaskState.isFinished(state)) { + freeCores += 1 + reviveOffers() + } + + case KillTask(taskId) => + executor.killTask(taskId) + } + + def reviveOffers() { + val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + for (task <- scheduler.resourceOffers(offers).flatten) { + freeCores -= 1 + executor.launchTask(executorBackend, task.taskId, task.serializedTask) + } + } +} + +/** + * LocalBackend is used when running a local version of Spark where the executor, backend, and + * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks + * on a single Executor (created by the LocalBackend) running locally. + */ +private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) + extends SchedulerBackend with ExecutorBackend { + + var localActor: ActorRef = null + + override def start() { + localActor = SparkEnv.get.actorSystem.actorOf( + Props(new LocalActor(scheduler, this, totalCores)), + "LocalBackendActor") + } + + override def stop() { + } + + override def reviveOffers() { + localActor ! ReviveOffers + } + + override def defaultParallelism() = totalCores + + override def killTask(taskId: Long, executorId: String) { + localActor ! KillTask(taskId) + } + + override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { + localActor !
StatusUpdate(taskId, state, serializedData) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala deleted file mode 100644 index 7c173e3ad5..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.local - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - -import akka.actor._ - -import org.apache.spark._ -import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode - - -/** - * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally - * the scheduler also allows each task to fail up to maxFailures times, which is useful for - * testing fault recovery. - */ - -private[local] -case class LocalReviveOffers() - -private[local] -case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) - -private[local] -case class KillTask(taskId: Long) - -private[spark] -class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) - extends Actor with Logging { - - val executor = new Executor( - "localhost", "localhost", localScheduler.sc.conf.getAll, isLocal = true) - - def receive = { - case LocalReviveOffers => - launchTask(localScheduler.resourceOffer(freeCores)) - - case LocalStatusUpdate(taskId, state, serializeData) => - if (TaskState.isFinished(state)) { - freeCores += 1 - launchTask(localScheduler.resourceOffer(freeCores)) - } - - case KillTask(taskId) => - executor.killTask(taskId) - } - - private def launchTask(tasks: Seq[TaskDescription]) { - for (task <- tasks) { - freeCores -= 1 - executor.launchTask(localScheduler, task.taskId, task.serializedTask) - } - } -} - -private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext) - extends TaskScheduler - with ExecutorBackend - with Logging { - - val env = SparkEnv.get - val conf = env.conf - val attemptId = new AtomicInteger - var dagScheduler: DAGScheduler = null - - // Application dependencies (added through SparkContext) that we've fetched so far on this node. - // Each map holds the master's timestamp for the version of that file or JAR we got. 
- val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - val schedulingMode: SchedulingMode = SchedulingMode.withName( - conf.getOrElse("spark.scheduler.mode", "FIFO")) - val activeTaskSets = new HashMap[String, LocalTaskSetManager] - val taskIdToTaskSetId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - var localActor: ActorRef = null - - override def start() { - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool, conf) - } - } - schedulableBuilder.buildPools() - - localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") - } - - override def setDAGScheduler(dagScheduler: DAGScheduler) { - this.dagScheduler = dagScheduler - } - - override def submitTasks(taskSet: TaskSet) { - synchronized { - val manager = new LocalTaskSetManager(this, taskSet) - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - activeTaskSets(taskSet.id) = manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - localActor ! LocalReviveOffers - } - } - - override def cancelTasks(stageId: Int): Unit = synchronized { - logInfo("Cancelling stage " + stageId) - logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId)) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - val taskIds = taskSetTaskIds(tsm.taskSet.id) - if (taskIds.size > 0) { - taskIds.foreach { tid => - localActor ! 
KillTask(tid) - } - } - logInfo("Stage %d was cancelled".format(stageId)) - taskSetFinished(tsm) - } - } - - def resourceOffer(freeCores: Int): Seq[TaskDescription] = { - synchronized { - var freeCpuCores = freeCores - val tasks = new ArrayBuffer[TaskDescription](freeCores) - val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() - for (manager <- sortedTaskSetQueue) { - logDebug("parentName:%s,name:%s,runningTasks:%s".format( - manager.parent.name, manager.name, manager.runningTasks)) - } - - var launchTask = false - for (manager <- sortedTaskSetQueue) { - do { - launchTask = false - manager.resourceOffer(null, null, freeCpuCores, null) match { - case Some(task) => - tasks += task - taskIdToTaskSetId(task.taskId) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += task.taskId - freeCpuCores -= 1 - launchTask = true - case None => {} - } - } while(launchTask) - } - return tasks - } - } - - def taskSetFinished(manager: TaskSetManager) { - synchronized { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds -= manager.taskSet.id - } - } - - override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - if (TaskState.isFinished(state)) { - synchronized { - taskIdToTaskSetId.get(taskId) match { - case Some(taskSetId) => - val taskSetManager = activeTaskSets.get(taskSetId) - taskSetManager.foreach { tsm => - taskSetTaskIds(taskSetId) -= taskId - - state match { - case TaskState.FINISHED => - tsm.taskEnded(taskId, state, serializedData) - case TaskState.FAILED => - tsm.taskFailed(taskId, state, serializedData) - case TaskState.KILLED => - tsm.error("Task %d was killed".format(taskId)) - case _ => {} - } - } - case None => - logInfo("Ignoring update from TID " + taskId + " because its task set is gone") - } - } - localActor ! LocalStatusUpdate(taskId, state, serializedData) - } - } - - override def stop() { - } - - override def defaultParallelism() = threads -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala deleted file mode 100644 index 53bf78267e..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.local - -import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState} -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task, - TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager} - - -private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) - extends TaskSetManager with Logging { - - var parent: Pool = null - var weight: Int = 1 - var minShare: Int = 0 - var runningTasks: Int = 0 - var priority: Int = taskSet.priority - var stageId: Int = taskSet.stageId - var name: String = "TaskSet_" + taskSet.stageId.toString - - var failCount = new Array[Int](taskSet.tasks.size) - val taskInfos = new HashMap[Long, TaskInfo] - val numTasks = taskSet.tasks.size - var numFinished = 0 - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val MAX_TASK_FAILURES = sched.maxFailures - - def increaseRunningTasks(taskNum: Int): Unit = { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - def decreaseRunningTasks(taskNum: Int): Unit = { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def removeSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - // nothing - } - - override def checkSpeculatableTasks() = true - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - override def hasPendingTasks() = true - - def findTask(): Option[Int] = { - for (i <- 0 to numTasks-1) { - if (copiesRunning(i) == 0 && !finished(i)) { - return Some(i) - } - } - return None - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - SparkEnv.set(sched.env) - logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format( - availableCpus.toInt, numFinished, numTasks)) - if (availableCpus > 0 && numFinished < numTasks) { - findTask() match { - case Some(index) => - val taskId = sched.attemptId.getAndIncrement() - val task = taskSet.tasks(index) - val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", - TaskLocality.NODE_LOCAL) - taskInfos(taskId) = info - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. 
- val bytes = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") - val taskName = "task %s:%d".format(taskSet.id, index) - copiesRunning(index) += 1 - increaseRunningTasks(1) - taskStarted(task, info) - return Some(new TaskDescription(taskId, null, taskName, index, bytes)) - case None => {} - } - } - return None - } - - def taskStarted(task: Task[_], info: TaskInfo) { - sched.dagScheduler.taskStarted(task, info) - } - - def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markSuccessful() - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => { - throw new SparkException("Expect only DirectTaskResults when using LocalScheduler") - } - } - result.metrics.resultSize = serializedData.limit() - sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info, - result.metrics) - numFinished += 1 - decreaseRunningTasks(1) - finished(index) = true - if (numFinished == numTasks) { - sched.taskSetFinished(this) - } - } - - def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markFailed() - decreaseRunningTasks(1) - val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( - serializedData, getClass.getClassLoader) - sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) - if (!finished(index)) { - copiesRunning(index) -= 1 - numFailures(index) += 1 - val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - reason.className, reason.description, locs.mkString("\n"))) - if (numFailures(index) > MAX_TASK_FAILURES) { - val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( - taskSet.id, index, MAX_TASK_FAILURES, reason.description) - decreaseRunningTasks(runningTasks) - sched.dagScheduler.taskSetFailed(taskSet, errorMessage) - // need to delete failed Taskset from schedule queue - sched.taskSetFinished(this) - } - } - } - - override def error(message: String) { - sched.dagScheduler.taskSetFailed(taskSet, message) - sched.taskSetFinished(this) - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index f592df283a..151eedb783 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -74,10 +74,16 @@ class ShuffleBlockManager(blockManager: BlockManager) { * Contains all the state related to a particular shuffle. This includes a pool of unused * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. */ - private class ShuffleState() { + private class ShuffleState(val numBuckets: Int) { val nextFileId = new AtomicInteger(0) val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() + + /** + * The mapIds of all map tasks completed on this Executor for this shuffle. + * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. 
+ */ + val completedMapTasks = new ConcurrentLinkedQueue[Int]() } type ShuffleId = Int @@ -88,7 +94,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState()) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) private var fileGroup: ShuffleFileGroup = null @@ -113,6 +119,8 @@ class ShuffleBlockManager(blockManager: BlockManager) { fileGroup.recordMapOutput(mapId, offsets) } recycleFileGroup(fileGroup) + } else { + shuffleState.completedMapTasks.add(mapId) } } @@ -158,7 +166,18 @@ class ShuffleBlockManager(blockManager: BlockManager) { } private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime) + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => { + if (consolidateShuffleFiles) { + for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + file.delete() + } + } else { + for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() + } + } + }) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index e596690bc3..a31a7e1d58 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -56,7 +56,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used", - "Active tasks", "Failed tasks", "Complete tasks", "Total tasks") + "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Task Time", "Shuffle Read", + "Shuffle Write") def execRow(kv: Seq[String]) = { <tr> @@ -73,6 +74,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { <td>{kv(7)}</td> <td>{kv(8)}</td> <td>{kv(9)}</td> + <td>{Utils.msDurationToString(kv(10).toLong)}</td> + <td>{Utils.bytesToString(kv(11).toLong)}</td> + <td>{Utils.bytesToString(kv(12).toLong)}</td> </tr> } @@ -111,6 +115,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks + val totalDuration = listener.executorToDuration.getOrElse(execId, 0) + val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0) + val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0) Seq( execId, @@ -122,7 +129,10 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { activeTasks.toString, failedTasks.toString, completedTasks.toString, - totalTasks.toString + totalTasks.toString, + totalDuration.toString, + totalShuffleRead.toString, + totalShuffleWrite.toString ) } @@ -130,6 +140,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]() val executorToTasksComplete = HashMap[String, Int]() val executorToTasksFailed = HashMap[String, Int]() + val executorToDuration = HashMap[String, Long]() + val executorToShuffleRead = HashMap[String, Long]() + val executorToShuffleWrite = 
HashMap[String, Long]() override def onTaskStart(taskStart: SparkListenerTaskStart) { val eid = taskStart.taskInfo.executorId @@ -140,6 +153,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { val eid = taskEnd.taskInfo.executorId val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + val newDuration = executorToDuration.getOrElse(eid, 0L) + taskEnd.taskInfo.duration + executorToDuration.put(eid, newDuration) + activeTasks -= taskEnd.taskInfo val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = taskEnd.reason match { @@ -150,6 +166,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 (None, Option(taskEnd.taskMetrics)) } + + // update shuffle read/write + if (null != taskEnd.taskMetrics) { + taskEnd.taskMetrics.shuffleReadMetrics.foreach(shuffleRead => + executorToShuffleRead.put(eid, executorToShuffleRead.getOrElse(eid, 0L) + + shuffleRead.remoteBytesRead)) + + taskEnd.taskMetrics.shuffleWriteMetrics.foreach(shuffleWrite => + executorToShuffleWrite.put(eid, executorToShuffleWrite.getOrElse(eid, 0L) + + shuffleWrite.shuffleBytesWritten)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala new file mode 100644 index 0000000000..3c53e88380 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +/** Class for reporting aggregated metrics for each executor in the stage UI. */ +private[spark] class ExecutorSummary { + var taskTime : Long = 0 + var failedTasks : Int = 0 + var succeededTasks : Int = 0 + var shuffleRead : Long = 0 + var shuffleWrite : Long = 0 +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala new file mode 100644 index 0000000000..0dd876480a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.xml.Node + +import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.util.Utils +import scala.collection.mutable + +/** Table showing per-executor summary metrics for a stage */ +private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) { + + val listener = parent.listener + val dateFmt = parent.dateFmt + val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + executorTable() + } + } + + /** Render the per-executor summary table. */ + private def executorTable[T](): Seq[Node] = { + <table class="table table-bordered table-striped table-condensed sortable"> + <thead> + <th>Executor ID</th> + <th>Address</th> + <th>Task Time</th> + <th>Total Tasks</th> + <th>Failed Tasks</th> + <th>Succeeded Tasks</th> + <th>Shuffle Read</th> + <th>Shuffle Write</th> + </thead> + <tbody> + {createExecutorTable()} + </tbody> + </table> + } + + private def createExecutorTable() : Seq[Node] = { + // make an executor-id -> address map + val executorIdToAddress = mutable.HashMap[String, String]() + val storageStatusList = parent.sc.getExecutorStorageStatus + for (statusId <- 0 until storageStatusList.size) { + val blockManagerId = parent.sc.getExecutorStorageStatus(statusId).blockManagerId + val address = blockManagerId.hostPort + val executorId = blockManagerId.executorId + executorIdToAddress.put(executorId, address) + } + + val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId) + executorIdToSummary match { + case Some(x) => { + x.toSeq.sortBy(_._1).map{ + case (k,v) => { + <tr> + <td>{k}</td> + <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td> + <td>{parent.formatDuration(v.taskTime)}</td> + <td>{v.failedTasks + v.succeededTasks}</td> + <td>{v.failedTasks}</td> + <td>{v.succeededTasks}</td> + <td>{Utils.bytesToString(v.shuffleRead)}</td> + <td>{Utils.bytesToString(v.shuffleWrite)}</td> + </tr> + } + } + } + case _ => { Seq[Node]() } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 6ff8e9fb14..eed3544b70 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -57,6 +57,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val stageIdToTasksFailed = HashMap[Int, Int]() val stageIdToTaskInfos = HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() + val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]() override def onJobStart(jobStart: SparkListenerJobStart) {} @@ -124,8 +125,38 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val sid = taskEnd.task.stageId + + // create executor summary map if necessary + val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid, + op
= new HashMap[String, ExecutorSummary]()) + executorSummaryMap.getOrElseUpdate(key = taskEnd.taskInfo.executorId, + op = new ExecutorSummary()) + + val executorSummary = executorSummaryMap.get(taskEnd.taskInfo.executorId) + executorSummary match { + case Some(y) => { + // first update failed-task, succeed-task + taskEnd.reason match { + case Success => + y.succeededTasks += 1 + case _ => + y.failedTasks += 1 + } + + // update duration + y.taskTime += taskEnd.taskInfo.duration + + Option(taskEnd.taskMetrics).foreach { taskMetrics => + taskMetrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } + taskMetrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } + } + } + case _ => {} + } + val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) tasksActive -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = taskEnd.reason match { case e: ExceptionFailure => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 996e1b4d1a..8dcfeacb60 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -66,7 +66,7 @@ private[spark] class StagePage(parent: JobProgressUI) { <div> <ul class="unstyled"> <li> - <strong>Total duration across all tasks: </strong> + <strong>Total task time across all tasks: </strong> {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)} </li> {if (hasShuffleRead) @@ -166,11 +166,12 @@ private[spark] class StagePage(parent: JobProgressUI) { def quantileRow(data: Seq[String]): Seq[Node] = <tr> {data.map(d => <td>{d}</td>)} </tr> Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) } - + val executorTable = new ExecutorTable(parent, stageId) val content = summary ++ <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++ <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++ + <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++ <h4>Tasks</h4> ++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 9ad6de3c6d..463d85dfd5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr {if (isFairScheduler) {<th>Pool Name</th>} else {}} <th>Description</th> <th>Submitted</th> - <th>Duration</th> + <th>Task Time</th> <th>Tasks: Succeeded/Total</th> <th>Shuffle Read</th> <th>Shuffle Write</th> diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 431d88838f..9ea7fc2dfd 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -32,7 +32,7 @@ class MetadataCleaner( { val name = cleanerType.toString - private val delaySeconds = MetadataCleaner.getDelaySeconds(conf) + private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType) private val periodSeconds = math.max(10, delaySeconds / 10) private val timer = new Timer(name + " cleanup timer", true) 
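For context on the per-stage executor summaries introduced above (the stageIdToExecutorSummaries map, the ExecutorSummary updates in onTaskEnd, and the new ExecutorTable page), the following is a minimal standalone Scala sketch of the aggregation the listener performs on each task-end event. The field names mirror the diff; the TaskResult case class here is a hypothetical stand-in for the fields SparkListenerTaskEnd actually carries, not a Spark API.

    // Standalone sketch of the per-stage, per-executor aggregation added in this
    // commit. ExecutorSummary fields mirror the diff; TaskResult is hypothetical.
    import scala.collection.mutable

    class ExecutorSummary {
      var taskTime: Long = 0L
      var failedTasks: Int = 0
      var succeededTasks: Int = 0
      var shuffleRead: Long = 0L
      var shuffleWrite: Long = 0L
    }

    // Simplified stand-in for the information delivered with a task-end event.
    case class TaskResult(stageId: Int, executorId: String, succeeded: Boolean,
                          duration: Long, shuffleRead: Long, shuffleWrite: Long)

    object ExecutorSummaryAggregation {
      val stageIdToExecutorSummaries =
        mutable.HashMap[Int, mutable.HashMap[String, ExecutorSummary]]()

      def onTaskEnd(t: TaskResult): Unit = {
        // Create the per-stage map and per-executor summary lazily, as the listener does.
        val byExecutor = stageIdToExecutorSummaries
          .getOrElseUpdate(t.stageId, mutable.HashMap[String, ExecutorSummary]())
        val summary = byExecutor.getOrElseUpdate(t.executorId, new ExecutorSummary)
        if (t.succeeded) summary.succeededTasks += 1 else summary.failedTasks += 1
        summary.taskTime += t.duration
        summary.shuffleRead += t.shuffleRead
        summary.shuffleWrite += t.shuffleWrite
      }

      def main(args: Array[String]): Unit = {
        onTaskEnd(TaskResult(0, "exec-1", succeeded = true, 120L, 1000L, 0L))
        onTaskEnd(TaskResult(0, "exec-1", succeeded = false, 40L, 0L, 0L))
        val s = stageIdToExecutorSummaries(0)("exec-1")
        println(s"succeeded=${s.succeededTasks} failed=${s.failedTasks} taskTime=${s.taskTime}")
      }
    }

These totals are exactly what the ExecutorTable renders per stage: total tasks is failedTasks + succeededTasks, and shuffle read/write are formatted with Utils.bytesToString.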
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index dbff571de9..181ae2fd45 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -104,19 +104,28 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { def toMap: immutable.Map[A, B] = iterator.toMap /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` + * Removes old key-value pairs that have timestamp earlier than `threshTime`, + * calling the supplied function on each such entry before removing. */ - def clearOldValues(threshTime: Long) { + def clearOldValues(threshTime: Long, f: (A, B) => Unit) { val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { + while (iterator.hasNext) { val entry = iterator.next() if (entry.getValue._2 < threshTime) { + f(entry.getKey, entry.getValue._1) logDebug("Removing key " + entry.getKey) iterator.remove() } } } + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + clearOldValues(threshTime, (_, _) => ()) + } + private def currentTime: Long = System.currentTimeMillis() } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index af448fcb37..befdc1589f 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -42,7 +42,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. test("failure in a single-stage job") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,2]", "test") val results = sc.makeRDD(1 to 3, 3).map { x => FailureSuiteState.synchronized { FailureSuiteState.tasksRun += 1 @@ -62,7 +62,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { // Run a map-reduce job in which a reduce task deterministically fails once. 
test("failure in a two-stage job") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,2]", "test") val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { case (k, v) => FailureSuiteState.synchronized { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 151af0d213..f28d5c7b13 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,20 +19,21 @@ package org.apache.spark import org.scalatest.{FunSuite, PrivateMethodTester} -import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend} +import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskScheduler} +import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.scheduler.local.LocalScheduler +import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { - def createTaskScheduler(master: String): TaskScheduler = { + def createTaskScheduler(master: String): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. sc = new SparkContext("local", "test") val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) - SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") + val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") + sched.asInstanceOf[TaskSchedulerImpl] } test("bad-master") { @@ -43,55 +44,49 @@ class SparkContextSchedulerCreationSuite } test("local") { - createTaskScheduler("local") match { - case s: LocalScheduler => - assert(s.threads === 1) - assert(s.maxFailures === 0) + val sched = createTaskScheduler("local") + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 1) case _ => fail() } } test("local-n") { - createTaskScheduler("local[5]") match { - case s: LocalScheduler => - assert(s.threads === 5) - assert(s.maxFailures === 0) + val sched = createTaskScheduler("local[5]") + assert(sched.maxTaskFailures === 1) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 5) case _ => fail() } } test("local-n-failures") { - createTaskScheduler("local[4, 2]") match { - case s: LocalScheduler => - assert(s.threads === 4) - assert(s.maxFailures === 2) + val sched = createTaskScheduler("local[4, 2]") + assert(sched.maxTaskFailures === 2) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 4) case _ => fail() } } test("simr") { - createTaskScheduler("simr://uri") match { - case s: ClusterScheduler => - assert(s.backend.isInstanceOf[SimrSchedulerBackend]) + createTaskScheduler("simr://uri").backend match { + case s: SimrSchedulerBackend => // OK case _ => fail() } } test("local-cluster") { - createTaskScheduler("local-cluster[3, 14, 512]") match { - case s: ClusterScheduler => - assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend]) + createTaskScheduler("local-cluster[3, 14, 512]").backend match { + case s: 
SparkDeploySchedulerBackend => // OK case _ => fail() } } def testYarn(master: String, expectedClassName: String) { try { - createTaskScheduler(master) match { - case s: ClusterScheduler => - assert(s.getClass === Class.forName(expectedClassName)) - case _ => fail() - } + val sched = createTaskScheduler(master) + assert(sched.getClass === Class.forName(expectedClassName)) } catch { case e: SparkException => assert(e.getMessage.contains("YARN mode not available")) @@ -110,11 +105,8 @@ class SparkContextSchedulerCreationSuite def testMesos(master: String, expectedClass: Class[_]) { try { - createTaskScheduler(master) match { - case s: ClusterScheduler => - assert(s.backend.getClass === expectedClass) - case _ => fail() - } + val sched = createTaskScheduler(master) + assert(sched.backend.getClass === expectedClass) } catch { case e: UnsatisfiedLinkError => assert(e.getMessage.contains("no mesos in")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala index 34d2e4cb8c..7bf2020fe3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -15,14 +15,12 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster._ import scala.collection.mutable.ArrayBuffer import java.util.Properties @@ -31,9 +29,9 @@ class FakeTaskSetManager( initPriority: Int, initStageId: Int, initNumTasks: Int, - clusterScheduler: ClusterScheduler, + clusterScheduler: TaskSchedulerImpl, taskSet: TaskSet) - extends ClusterTaskSetManager(clusterScheduler, taskSet) { + extends TaskSetManager(clusterScheduler, taskSet, 0) { parent = null weight = 1 @@ -106,7 +104,7 @@ class FakeTaskSetManager( class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = { new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) } @@ -133,7 +131,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("FIFO Scheduler Test") { sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -160,7 +158,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("Fair Scheduler Test") { sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -217,7 +215,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("Nested Pool Test") { sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new 
FakeTask(0) tasks += task diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 0f01515179..0b90c4e74c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import org.apache.spark.TaskContext -import org.apache.spark.scheduler.{TaskLocation, Task} class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { override def runTask(context: TaskContext): Int = 0 diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 2e41438a52..d4320e5e14 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -19,23 +19,26 @@ package org.apache.spark.scheduler import scala.collection.mutable.{Buffer, HashSet} -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.matchers.ShouldMatchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.SparkContext._ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers - with BeforeAndAfterAll { + with BeforeAndAfter with BeforeAndAfterAll { /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 + before { + sc = new SparkContext("local", "SparkListenerSuite") + } + override def afterAll { System.clearProperty("spark.akka.frameSize") } test("basic creation of StageInfo") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -56,7 +59,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("StageInfo with fewer tasks than partitions") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -72,7 +74,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("local metrics") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) @@ -135,10 +136,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("onTaskGettingResult() called when result fetched remotely") { - // Need to use local cluster mode here, because results are not ever returned through the - // block manager when using the LocalScheduler. - sc = new SparkContext("local-cluster[1,1,512]", "test") - val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -157,10 +154,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("onTaskGettingResult() not called when result sent directly") { - // Need to use local cluster mode here, because results are not ever returned through the - // block manager when using the LocalScheduler. 
- sc = new SparkContext("local-cluster[1,1,512]", "test") - val listener = new SaveTaskEvents sc.addSparkListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 618fae7c16..4b52d9651e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -15,14 +15,13 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import java.nio.ByteBuffer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.apache.spark.{SparkConf, LocalSparkContext, SparkContext, SparkEnv} -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} import org.apache.spark.storage.TaskResultBlockId /** @@ -31,12 +30,12 @@ import org.apache.spark.storage.TaskResultBlockId * Used to test the case where a BlockManager evicts the task result (or dies) before the * TaskResult is retrieved. */ -class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) +class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends TaskResultGetter(sparkEnv, scheduler) { var removedResult = false override def enqueueSuccessfulTask( - taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) { + taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { if (!removedResult) { // Only remove the result once, since we'd like to test the case where the task eventually // succeeds. @@ -44,13 +43,13 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSched case IndirectTaskResult(blockId) => sparkEnv.blockManager.master.removeBlock(blockId) case directResult: DirectTaskResult[_] => - taskSetManager.abort("Internal error: expect only indirect results") + taskSetManager.abort("Internal error: expect only indirect results") } serializedData.rewind() removedResult = true } super.enqueueSuccessfulTask(taskSetManager, tid, serializedData) - } + } } /** @@ -65,22 +64,18 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA System.setProperty("spark.akka.frameSize", "1") } - before { - // Use local-cluster mode because results are returned differently when running with the - // LocalScheduler. - sc = new SparkContext("local-cluster[1,1,512]", "test") - } - override def afterAll { System.clearProperty("spark.akka.frameSize") } test("handling results smaller than Akka frame size") { + sc = new SparkContext("local", "test") val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) assert(result === 2) } - test("handling results larger than Akka frame size") { + test("handling results larger than Akka frame size") { + sc = new SparkContext("local", "test") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) @@ -92,10 +87,13 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA } test("task retried if result missing from block manager") { + // Set the maximum number of task failures to > 0, so that the task set isn't aborted + // after the result is missing. 
+ sc = new SparkContext("local[1,2]", "test") // If this test hangs, it's probably because no resource offers were made after the task // failed. - val scheduler: ClusterScheduler = sc.taskScheduler match { - case clusterScheduler: ClusterScheduler => + val scheduler: TaskSchedulerImpl = sc.taskScheduler match { + case clusterScheduler: TaskSchedulerImpl => clusterScheduler case _ => assert(false, "Expect local cluster to use ClusterScheduler") diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 3711382f2e..5d33e66253 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import scala.collection.mutable.ArrayBuffer import scala.collection.mutable @@ -23,7 +23,6 @@ import scala.collection.mutable import org.scalatest.FunSuite import org.apache.spark._ -import org.apache.spark.scheduler._ import org.apache.spark.executor.TaskMetrics import java.nio.ByteBuffer import org.apache.spark.util.{Utils, FakeClock} @@ -56,10 +55,10 @@ class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler * A mock ClusterScheduler implementation that just remembers information about tasks started and * feedback received from the TaskSetManagers. Note that it's important to initialize this with * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost - * to work, and these are required for locality in ClusterTaskSetManager. + * to work, and these are required for locality in TaskSetManager. 
*/ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) - extends ClusterScheduler(sc) + extends TaskSchedulerImpl(sc) { val startedTasks = new ArrayBuffer[Long] val endedTasks = new mutable.HashMap[Long, TaskEndReason] @@ -79,16 +78,19 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) } -class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} + private val conf = new SparkConf + val LOCALITY_WAIT = conf.getOrElse("spark.locality.wait", "3000").toLong + val MAX_TASK_FAILURES = 4 test("TaskSet with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) - val manager = new ClusterTaskSetManager(sched, taskSet) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // Offer a host with no CPUs assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) @@ -114,7 +116,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo sc = new SparkContext("local", "test") val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(3) - val manager = new ClusterTaskSetManager(sched, taskSet) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // First three offers should all find tasks for (i <- 0 until 3) { @@ -151,7 +153,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo Seq() // Last task has no locality prefs ) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -197,7 +199,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo Seq(TaskLocation("host2")) ) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -234,7 +236,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo Seq(TaskLocation("host3")) ) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -262,7 +264,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -279,17 +281,17 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = 
createTaskSet(1) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted // after the last failure. - (1 to manager.MAX_TASK_FAILURES).foreach { index => + (1 to manager.maxTaskFailures).foreach { index => val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) assert(offerResult != None, "Expect resource offer on iteration %s to return a task".format(index)) assert(offerResult.get.index === 0) manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost)) - if (index < manager.MAX_TASK_FAILURES) { + if (index < MAX_TASK_FAILURES) { assert(!sched.taskSetsFailed.contains(taskSet.id)) } else { assert(sched.taskSetsFailed.contains(taskSet.id)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala deleted file mode 100644 index 1e676c1719..0000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.local - -import java.util.concurrent.Semaphore -import java.util.concurrent.CountDownLatch - -import scala.collection.mutable.HashMap - -import org.scalatest.{BeforeAndAfterEach, FunSuite} - -import org.apache.spark._ - - -class Lock() { - var finished = false - def jobWait() = { - synchronized { - while(!finished) { - this.wait() - } - } - } - - def jobFinished() = { - synchronized { - finished = true - this.notifyAll() - } - } -} - -object TaskThreadInfo { - val threadToLock = HashMap[Int, Lock]() - val threadToRunning = HashMap[Int, Boolean]() - val threadToStarted = HashMap[Int, CountDownLatch]() -} - -/* - * 1. each thread contains one job. - * 2. each job contains one stage. - * 3. each stage only contains one task. - * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure - * it will get cpu core resource, and will wait to finished after user manually - * release "Lock" and then cluster will contain another free cpu cores. - * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, - * thus it will be scheduled later when cluster has free cpu cores. 
- */ -class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach { - - override def afterEach() { - super.afterEach() - System.clearProperty("spark.scheduler.mode") - } - - def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { - - TaskThreadInfo.threadToRunning(threadIndex) = false - val nums = sc.parallelize(threadIndex to threadIndex, 1) - TaskThreadInfo.threadToLock(threadIndex) = new Lock() - TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) - new Thread { - if (poolName != null) { - sc.setLocalProperty("spark.scheduler.pool", poolName) - } - override def run() { - val ans = nums.map(number => { - TaskThreadInfo.threadToRunning(number) = true - TaskThreadInfo.threadToStarted(number).countDown() - TaskThreadInfo.threadToLock(number).jobWait() - TaskThreadInfo.threadToRunning(number) = false - number - }).collect() - assert(ans.toList === List(threadIndex)) - sem.release() - } - }.start() - } - - test("Local FIFO scheduler end-to-end test") { - System.setProperty("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local[4]", "test") - val sem = new Semaphore(0) - - createThread(1,null,sc,sem) - TaskThreadInfo.threadToStarted(1).await() - createThread(2,null,sc,sem) - TaskThreadInfo.threadToStarted(2).await() - createThread(3,null,sc,sem) - TaskThreadInfo.threadToStarted(3).await() - createThread(4,null,sc,sem) - TaskThreadInfo.threadToStarted(4).await() - // thread 5 and 6 (stage pending)must meet following two points - // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager - // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() - // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 - // So I just use "sleep" 1s here for each thread. - // TODO: any better solution? 
- createThread(5,null,sc,sem) - Thread.sleep(1000) - createThread(6,null,sc,sem) - Thread.sleep(1000) - - assert(TaskThreadInfo.threadToRunning(1) === true) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === false) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(1).jobFinished() - TaskThreadInfo.threadToStarted(5).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(3).jobFinished() - TaskThreadInfo.threadToStarted(6).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === false) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === true) - - TaskThreadInfo.threadToLock(2).jobFinished() - TaskThreadInfo.threadToLock(4).jobFinished() - TaskThreadInfo.threadToLock(5).jobFinished() - TaskThreadInfo.threadToLock(6).jobFinished() - sem.acquire(6) - } - - test("Local fair scheduler end-to-end test") { - System.setProperty("spark.scheduler.mode", "FAIR") - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - - sc = new SparkContext("local[8]", "LocalSchedulerSuite") - val sem = new Semaphore(0) - - createThread(10,"1",sc,sem) - TaskThreadInfo.threadToStarted(10).await() - createThread(20,"2",sc,sem) - TaskThreadInfo.threadToStarted(20).await() - createThread(30,"3",sc,sem) - TaskThreadInfo.threadToStarted(30).await() - - assert(TaskThreadInfo.threadToRunning(10) === true) - assert(TaskThreadInfo.threadToRunning(20) === true) - assert(TaskThreadInfo.threadToRunning(30) === true) - - createThread(11,"1",sc,sem) - TaskThreadInfo.threadToStarted(11).await() - createThread(21,"2",sc,sem) - TaskThreadInfo.threadToStarted(21).await() - createThread(31,"3",sc,sem) - TaskThreadInfo.threadToStarted(31).await() - - assert(TaskThreadInfo.threadToRunning(11) === true) - assert(TaskThreadInfo.threadToRunning(21) === true) - assert(TaskThreadInfo.threadToRunning(31) === true) - - createThread(12,"1",sc,sem) - TaskThreadInfo.threadToStarted(12).await() - createThread(22,"2",sc,sem) - TaskThreadInfo.threadToStarted(22).await() - createThread(32,"3",sc,sem) - - assert(TaskThreadInfo.threadToRunning(12) === true) - assert(TaskThreadInfo.threadToRunning(22) === true) - assert(TaskThreadInfo.threadToRunning(32) === false) - - TaskThreadInfo.threadToLock(10).jobFinished() - TaskThreadInfo.threadToStarted(32).await() - - assert(TaskThreadInfo.threadToRunning(32) === true) - - //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager - // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. - //2. priority of 23 and 33 will be meaningless as using fair scheduler here. 
- createThread(23,"2",sc,sem) - createThread(33,"3",sc,sem) - Thread.sleep(1000) - - TaskThreadInfo.threadToLock(11).jobFinished() - TaskThreadInfo.threadToStarted(23).await() - - assert(TaskThreadInfo.threadToRunning(23) === true) - assert(TaskThreadInfo.threadToRunning(33) === false) - - TaskThreadInfo.threadToLock(12).jobFinished() - TaskThreadInfo.threadToStarted(33).await() - - assert(TaskThreadInfo.threadToRunning(33) === true) - - TaskThreadInfo.threadToLock(20).jobFinished() - TaskThreadInfo.threadToLock(21).jobFinished() - TaskThreadInfo.threadToLock(22).jobFinished() - TaskThreadInfo.threadToLock(23).jobFinished() - TaskThreadInfo.threadToLock(30).jobFinished() - TaskThreadInfo.threadToLock(31).jobFinished() - TaskThreadInfo.threadToLock(32).jobFinished() - TaskThreadInfo.threadToLock(33).jobFinished() - - sem.acquire(11) - } -} diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala new file mode 100644 index 0000000000..67a57a0e7f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import org.scalatest.FunSuite +import org.apache.spark.scheduler._ +import org.apache.spark.{LocalSparkContext, SparkContext, Success} +import org.apache.spark.scheduler.SparkListenerTaskStart +import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} + +class JobProgressListenerSuite extends FunSuite with LocalSparkContext { + test("test executor id to summary") { + val sc = new SparkContext("local", "test") + val listener = new JobProgressListener(sc) + val taskMetrics = new TaskMetrics() + val shuffleReadMetrics = new ShuffleReadMetrics() + + // nothing in it + assert(listener.stageIdToExecutorSummaries.size == 0) + + // finish this task, should get updated shuffleRead + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) + .shuffleRead == 1000) + + // finish a task with unknown executor-id, nothing should happen + taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.stageIdToExecutorSummaries.size == 1) + + // finish this task, should get updated duration + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) + .shuffleRead == 2000) + + // finish this task, should get updated duration + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + taskInfo = new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail()) + .shuffleRead == 1000) + } +} diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java index 9a8e4209ed..22994fb2ec 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java @@ -53,7 +53,7 @@ public class JavaKafkaWordCount { } // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "KafkaWordCount", new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); int numThreads = Integer.parseInt(args[4]); diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala new 
file mode 100644 index 0000000000..8247c1ebc5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.classification._ +import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.recommendation._ +import org.apache.spark.rdd.RDD +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.DoubleBuffer + +/** + * The Java stubs necessary for the Python mllib bindings. + */ +class PythonMLLibAPI extends Serializable { + private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { + val packetLength = bytes.length + if (packetLength < 16) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.getLong() + if (magic != 1) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val length = bb.getLong() + if (packetLength != 16 + 8 * length) { + throw new IllegalArgumentException("Length " + length + " is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Double](length.toInt) + db.get(ans) + return ans + } + + private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { + val len = doubles.length + val bytes = new Array[Byte](16 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putLong(1) + bb.putLong(len) + val db = bb.asDoubleBuffer() + db.put(doubles) + return bytes + } + + private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { + val packetLength = bytes.length + if (packetLength < 24) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.getLong() + if (magic != 2) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val rows = bb.getLong() + val cols = bb.getLong() + if (packetLength != 24 + 8 * rows * cols) { + throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Array[Double]](rows.toInt) + var i = 0 + for (i <- 0 until rows.toInt) { + ans(i) = new Array[Double](cols.toInt) + db.get(ans(i)) + } + return ans + } + + private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { + val rows = doubles.length + var cols = 0 + if (rows > 0) { + cols = doubles(0).length + } + val bytes = new Array[Byte](24 + 8 * rows * cols) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putLong(2) 
+ bb.putLong(rows) + bb.putLong(cols) + val db = bb.asDoubleBuffer() + var i = 0 + for (i <- 0 until rows) { + db.put(doubles(i)) + } + return bytes + } + + private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, + dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): + java.util.LinkedList[java.lang.Object] = { + val data = dataBytesJRDD.rdd.map(xBytes => { + val x = deserializeDoubleVector(xBytes) + LabeledPoint(x(0), x.slice(1, x.length)) + }) + val initialWeights = deserializeDoubleVector(initialWeightsBA) + val model = trainFunc(data, initialWeights) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleVector(model.weights)) + ret.add(model.intercept: java.lang.Double) + return ret + } + + /** + * Java stub for Python mllib LinearRegressionWithSGD.train() + */ + def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, stepSize: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LinearRegressionWithSGD.train(data, numIterations, stepSize, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib LassoWithSGD.train() + */ + def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LassoWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib RidgeRegressionWithSGD.train() + */ + def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib SVMWithSGD.train() + */ + def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + SVMWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib LogisticRegressionWithSGD.train() + */ + def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, stepSize: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LogisticRegressionWithSGD.train(data, numIterations, stepSize, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + /** + * Java stub for Python mllib KMeans.train() + */ + def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, + maxIterations: Int, runs: Int, initializationMode: String): + java.util.List[java.lang.Object] = { + val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + val model = 
KMeans.train(data, k, maxIterations, runs, initializationMode) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleMatrix(model.clusterCenters)) + return ret + } + + private def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + return new Rating(user, product, rating) + } + + /** + * Java stub for Python mllib ALS.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.train(ratings, rank, iterations, lambda, blocks) + } + + /** + * Java stub for Python mllib ALS.trainImplicit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 963b5b88be..1bba6a5ae4 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -437,8 +437,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } def monitorApplication(appId: ApplicationId): Boolean = { + val interval = new SparkConf().getOrElse("spark.yarn.report.interval", "1000").toLong + while (true) { - Thread.sleep(1000) + Thread.sleep(interval) val report = super.getApplicationReport(appId) logInfo("Application report from ASM: \n" + diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 71d1cbd416..abc3447746 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -27,8 +27,8 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import org.apache.spark.Logging -import org.apache.spark.scheduler.SplitInfo -import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} +import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils import org.apache.hadoop.conf.Configuration @@ -233,9 +233,9 @@ private[yarn] class YarnAllocationHandler( // Note that the list we create below tries to ensure that not all containers end up within // a host if there is a sufficiently large number of hosts/containers. 
val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) - allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) - allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) - allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(offRackContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) // Run each of the allocated containers. for (container <- allocatedContainersToProcess) { diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 63a0449e5a..522e0a9ad7 100644 --- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -20,13 +20,14 @@ package org.apache.spark.scheduler.cluster import org.apache.spark._ import org.apache.hadoop.conf.Configuration import org.apache.spark.deploy.yarn.YarnAllocationHandler +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils /** * * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM. */ -private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { +private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { def this(sc: SparkContext) = this(sc, new Configuration()) diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 6feaaff014..4b69f5078b 100644 --- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,9 +20,10 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 29b3f22e13..a4638cc863 100644 --- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.spark._ import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils import org.apache.hadoop.conf.Configuration @@ -26,7 +27,7 @@ import 
org.apache.hadoop.conf.Configuration * * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done */ -private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { logInfo("Created YarnClusterScheduler") diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ffb54a24ac..37d6f1b60d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -114,6 +114,9 @@ object SparkBuild extends Build { fork := true, javaOptions += "-Xmx3g", + // Show full stack trace and duration in test cases. + testOptions in Test += Tests.Argument("-oDF"), + // Only allow one test at a time, even across projects, since they run in the same JVM concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), @@ -260,7 +263,7 @@ object SparkBuild extends Build { libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v ) ) - + def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", libraryDependencies ++= Seq( diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 128f078d12..d8ca9fce00 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -63,5 +63,6 @@ def launch_gateway(): java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py new file mode 100644 index 0000000000..b1a5df109b --- /dev/null +++ b/python/pyspark/mllib/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python bindings for MLlib. +""" diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py new file mode 100644 index 0000000000..e74ba0fabc --- /dev/null +++ b/python/pyspark/mllib/_common.py @@ -0,0 +1,227 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape +from pyspark import SparkContext + +# Double vector format: +# +# [8-byte 1] [8-byte length] [length*8 bytes of data] +# +# Double matrix format: +# +# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data] +# +# This is all in machine-endian. That means that the Java interpreter and the +# Python interpreter must agree on what endian the machine is. + +def _deserialize_byte_array(shape, ba, offset): + """Wrapper around ndarray aliasing hack. + + >>> x = array([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0)) + True + >>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2) + >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0)) + True + """ + ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", + order='C') + return ar.copy() + +def _serialize_double_vector(v): + """Serialize a double vector into a mutually understood format.""" + if type(v) != ndarray: + raise TypeError("_serialize_double_vector called on a %s; " + "wanted ndarray" % type(v)) + if v.dtype != float64: + raise TypeError("_serialize_double_vector called on an ndarray of %s; " + "wanted ndarray of float64" % v.dtype) + if v.ndim != 1: + raise TypeError("_serialize_double_vector called on a %ddarray; " + "wanted a 1darray" % v.ndim) + length = v.shape[0] + ba = bytearray(16 + 8*length) + header = ndarray(shape=[2], buffer=ba, dtype="int64") + header[0] = 1 + header[1] = length + copyto(ndarray(shape=[length], buffer=ba, offset=16, + dtype="float64"), v) + return ba + +def _deserialize_double_vector(ba): + """Deserialize a double vector from a mutually understood format. 
+ + >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]) + >>> array_equal(x, _deserialize_double_vector(_serialize_double_vector(x))) + True + """ + if type(ba) != bytearray: + raise TypeError("_deserialize_double_vector called on a %s; " + "wanted bytearray" % type(ba)) + if len(ba) < 16: + raise TypeError("_deserialize_double_vector called on a %d-byte array, " + "which is too short" % len(ba)) + if (len(ba) & 7) != 0: + raise TypeError("_deserialize_double_vector called on a %d-byte array, " + "which is not a multiple of 8" % len(ba)) + header = ndarray(shape=[2], buffer=ba, dtype="int64") + if header[0] != 1: + raise TypeError("_deserialize_double_vector called on bytearray " + "with wrong magic") + length = header[1] + if len(ba) != 8*length + 16: + raise TypeError("_deserialize_double_vector called on bytearray " + "with wrong length") + return _deserialize_byte_array([length], ba, 16) + +def _serialize_double_matrix(m): + """Serialize a double matrix into a mutually understood format.""" + if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2): + rows = m.shape[0] + cols = m.shape[1] + ba = bytearray(24 + 8 * rows * cols) + header = ndarray(shape=[3], buffer=ba, dtype="int64") + header[0] = 2 + header[1] = rows + header[2] = cols + copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24, + dtype="float64", order='C'), m) + return ba + else: + raise TypeError("_serialize_double_matrix called on a " + "non-double-matrix") + +def _deserialize_double_matrix(ba): + """Deserialize a double matrix from a mutually understood format.""" + if type(ba) != bytearray: + raise TypeError("_deserialize_double_matrix called on a %s; " + "wanted bytearray" % type(ba)) + if len(ba) < 24: + raise TypeError("_deserialize_double_matrix called on a %d-byte array, " + "which is too short" % len(ba)) + if (len(ba) & 7) != 0: + raise TypeError("_deserialize_double_matrix called on a %d-byte array, " + "which is not a multiple of 8" % len(ba)) + header = ndarray(shape=[3], buffer=ba, dtype="int64") + if (header[0] != 2): + raise TypeError("_deserialize_double_matrix called on bytearray " + "with wrong magic") + rows = header[1] + cols = header[2] + if (len(ba) != 8*rows*cols + 24): + raise TypeError("_deserialize_double_matrix called on bytearray " + "with wrong length") + return _deserialize_byte_array([rows, cols], ba, 24) + +def _linear_predictor_typecheck(x, coeffs): + """Check that x is a one-dimensional vector of the right shape. 
+ This is a temporary workaround until bulk predict is actually implemented.""" + if type(x) == ndarray: + if x.ndim == 1: + if x.shape == coeffs.shape: + pass + else: + raise RuntimeError("Got array of %d elements; wanted %d" + % (shape(x)[0], shape(coeffs)[0])) + else: + raise RuntimeError("Bulk predict not yet supported.") + elif type(x).__name__ == "RDD": + # Compared by class name so that this module does not need to import RDD. + raise RuntimeError("Bulk predict not yet supported.") + else: + raise TypeError("Argument of type " + type(x).__name__ + " unsupported") + +def _get_unmangled_rdd(data, serializer): + dataBytes = data.map(serializer) + dataBytes._bypass_serializer = True + dataBytes.cache() + return dataBytes + +# Map a pickled Python RDD of numpy double vectors to a Java RDD of +# _serialized_double_vectors +def _get_unmangled_double_vector_rdd(data): + return _get_unmangled_rdd(data, _serialize_double_vector) + +class LinearModel(object): + """Something that has a vector of coefficients and an intercept.""" + def __init__(self, coeff, intercept): + self._coeff = coeff + self._intercept = intercept + +class LinearRegressionModelBase(LinearModel): + """A linear regression model. + + >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) + >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6 + True + """ + def predict(self, x): + """Predict the value of the dependent variable given a vector x + containing values for the independent variables.""" + _linear_predictor_typecheck(x, self._coeff) + return dot(self._coeff, x) + self._intercept + +# If we weren't given initial weights, take a vector of ones of the appropriate +# length (one fewer than a data row, which also carries the label). +def _get_initial_weights(initial_weights, data): + if initial_weights is None: + initial_weights = data.first() + if type(initial_weights) != ndarray: + raise TypeError("At least one data element has type " + + type(initial_weights).__name__ + " which is not ndarray") + if initial_weights.ndim != 1: + raise TypeError("At least one data element has " + + str(initial_weights.ndim) + " dimensions, which is not 1") + initial_weights = ones([initial_weights.shape[0] - 1]) + return initial_weights + +# train_func should take two parameters, namely data and initial_weights, and +# return the result of a call to the appropriate JVM stub. +# _regression_train_wrapper is responsible for setup and error checking.
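+# For illustration, mirroring the callers in classification.py and regression.py
+# below, a typical train_func looks like:
+#
+#     lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
+#         d._jrdd, iterations, step, mini_batch_fraction, i)
+#
+# The JVM stub is expected to return a two-element result: the serialized weight
+# vector (a bytearray) and the intercept (a float), which the wrapper below
+# validates and deserializes.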
+def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): + initial_weights = _get_initial_weights(initial_weights, data) + dataBytes = _get_unmangled_double_vector_rdd(data) + ans = train_func(dataBytes, _serialize_double_vector(initial_weights)) + if len(ans) != 2: + raise RuntimeError("JVM call result had unexpected length") + elif type(ans[0]) != bytearray: + raise RuntimeError("JVM call result had first element of type " + + type(ans[0]).__name__ + " which is not bytearray") + elif type(ans[1]) != float: + raise RuntimeError("JVM call result had second element of type " + + type(ans[0]).__name__ + " which is not float") + return klass(_deserialize_double_vector(ans[0]), ans[1]) + +def _serialize_rating(r): + ba = bytearray(16) + intpart = ndarray(shape=[2], buffer=ba, dtype=int32) + doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8) + intpart[0], intpart[1], doublepart[0] = r + return ba + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py new file mode 100644 index 0000000000..70de332d34 --- /dev/null +++ b/python/pyspark/mllib/classification.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from numpy import array, dot, shape +from pyspark import SparkContext +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ + LinearModel, _linear_predictor_typecheck +from math import exp, log + +class LogisticRegressionModel(LinearModel): + """A linear binary classification model derived from logistic regression. 
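+ The model computes the margin dot(coefficients, x) + intercept and predicts 1 if the logistic (sigmoid) of the margin exceeds 0.5, otherwise 0; see predict() below.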
+ + >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2) + >>> lrm = LogisticRegressionWithSGD.train(sc, sc.parallelize(data)) + >>> lrm.predict(array([1.0])) != None + True + """ + def predict(self, x): + _linear_predictor_typecheck(x, self._coeff) + margin = dot(x, self._coeff) + self._intercept + prob = 1/(1 + exp(-margin)) + return 1 if prob > 0.5 else 0 + +class LogisticRegressionWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a logistic regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(d._jrdd, + iterations, step, mini_batch_fraction, i), + LogisticRegressionModel, data, initial_weights) + +class SVMModel(LinearModel): + """A support vector machine. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2) + >>> svm = SVMWithSGD.train(sc, sc.parallelize(data)) + >>> svm.predict(array([1.0])) != None + True + """ + def predict(self, x): + _linear_predictor_typecheck(x, self._coeff) + margin = dot(x, self._coeff) + self._intercept + return 1 if margin >= 0 else 0 + +class SVMWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a support vector machine on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + SVMModel, data, initial_weights) + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py new file mode 100644 index 0000000000..8cf20e591a --- /dev/null +++ b/python/pyspark/mllib/clustering.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from numpy import array, dot +from math import sqrt +from pyspark import SparkContext +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper + +class KMeansModel(object): + """A clustering model derived from the k-means method. 
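+ The model holds the cluster centers; predict() assigns a point to the center with the smallest Euclidean distance.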
+ + >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) + >>> clusters = KMeans.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random") + >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0])) + True + >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0])) + True + >>> clusters = KMeans.train(sc, sc.parallelize(data), 2) + """ + def __init__(self, centers_): + self.centers = centers_ + + def predict(self, x): + """Find the cluster to which x belongs in this model.""" + best = 0 + best_distance = 1e75 + for i in range(0, self.centers.shape[0]): + diff = x - self.centers[i] + distance = sqrt(dot(diff, diff)) + if distance < best_distance: + best = i + best_distance = distance + return best + +class KMeans(object): + @classmethod + def train(cls, sc, data, k, maxIterations=100, runs=1, + initialization_mode="k-means||"): + """Train a k-means clustering model.""" + dataBytes = _get_unmangled_double_vector_rdd(data) + ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd, + k, maxIterations, runs, initialization_mode) + if len(ans) != 1: + raise RuntimeError("JVM call result had unexpected length") + elif type(ans[0]) != bytearray: + raise RuntimeError("JVM call result had first element of type " + + type(ans[0]) + " which is not bytearray") + return KMeansModel(_deserialize_double_matrix(ans[0])) + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py new file mode 100644 index 0000000000..14d06cba21 --- /dev/null +++ b/python/pyspark/mllib/recommendation.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper + +class MatrixFactorizationModel(object): + """A matrix factorisation model trained by regularized alternating + least-squares. 
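+ predict(user, product) returns the model's estimated rating for that user-product pair, delegating to the underlying JVM model.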
+ + >>> r1 = (1, 1, 1.0) + >>> r2 = (1, 2, 2.0) + >>> r3 = (2, 1, 2.0) + >>> ratings = sc.parallelize([r1, r2, r3]) + >>> model = ALS.trainImplicit(sc, ratings, 1) + >>> model.predict(2,2) is not None + True + """ + + def __init__(self, sc, java_model): + self._context = sc + self._java_model = java_model + + def __del__(self): + self._context._gateway.detach(self._java_model) + + def predict(self, user, product): + return self._java_model.predict(user, product) + +class ALS(object): + @classmethod + def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): + ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd, + rank, iterations, lambda_, blocks) + return MatrixFactorizationModel(sc, mod) + + @classmethod + def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): + ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd, + rank, iterations, lambda_, blocks, alpha) + return MatrixFactorizationModel(sc, mod) + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py new file mode 100644 index 0000000000..a3a68b29e0 --- /dev/null +++ b/python/pyspark/mllib/regression.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from numpy import array, dot +from pyspark import SparkContext +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ + _serialize_double_matrix, _deserialize_double_matrix, \ + _serialize_double_vector, _deserialize_double_vector, \ + _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ + _linear_predictor_typecheck + +class LinearModel(object): + """Something that has a vector of coefficients and an intercept.""" + def __init__(self, coeff, intercept): + self._coeff = coeff + self._intercept = intercept + +class LinearRegressionModelBase(LinearModel): + """A linear regression model. 
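+ Prediction is dot(coefficients, x) + intercept; with coefficients [1.0, 2.0] and intercept 0.1 as below, predicting on [-1.03, 7.777] gives 1.0*(-1.03) + 2.0*7.777 + 0.1 = 14.624.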
+ + >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) + >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6 + True + """ + def predict(self, x): + """Predict the value of the dependent variable given a vector x""" + """containing values for the independent variables.""" + _linear_predictor_typecheck(x, self._coeff) + return dot(self._coeff, x) + self._intercept + +class LinearRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ + +class LinearRegressionWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a linear regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( + d._jrdd, iterations, step, mini_batch_fraction, i), + LinearRegressionModel, data, initial_weights) + +class LassoModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an + l_1 penalty term. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = LassoWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ + +class LassoWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a Lasso regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + LassoModel, data, initial_weights) + +class RidgeRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an + l_2 penalty term. + + >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) + >>> lrm = RidgeRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + """ + +class RidgeRegressionWithSGD(object): + @classmethod + def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + mini_batch_fraction=1.0, initial_weights=None): + """Train a ridge regression model on the given data.""" + return _regression_train_wrapper(sc, lambda d, i: + sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd, + iterations, step, reg_param, mini_batch_fraction, i), + RidgeRegressionModel, data, initial_weights) + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 811fa6f018..2a500ab919 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -308,4 +308,4 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) - stream.write(obj)
\ No newline at end of file + stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index a475959090..ef07eb437b 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -42,7 +42,7 @@ print "Using Python version %s (%s, %s)" % ( platform.python_version(), platform.python_build()[0], platform.python_build()[1]) -print "Spark context avaiable as sc." +print "Spark context available as sc." if add_files != None: print "Adding files: [%s]" % ", ".join(add_files) diff --git a/spark-class b/spark-class index 4eb95a9ba2..802e4aa104 100755 --- a/spark-class +++ b/spark-class @@ -129,11 +129,11 @@ fi # Compute classpath using external script CLASSPATH=`$FWDIR/bin/compute-classpath.sh` -CLASSPATH="$SPARK_TOOLS_JAR:$CLASSPATH" +CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" if $cygwin; then - CLASSPATH=`cygpath -wp $CLASSPATH` - export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR` + CLASSPATH=`cygpath -wp $CLASSPATH` + export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR` fi export CLASSPATH diff --git a/spark-class2.cmd b/spark-class2.cmd index 3869d0761b..dc9dadf356 100644 --- a/spark-class2.cmd +++ b/spark-class2.cmd @@ -17,7 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SCALA_VERSION=2.9.3 +set SCALA_VERSION=2.10 rem Figure out where the Spark framework is installed set FWDIR=%~dp0 @@ -75,7 +75,7 @@ rem Compute classpath using external script set DONT_PRINT_CLASSPATH=1 call "%FWDIR%bin\compute-classpath.cmd" set DONT_PRINT_CLASSPATH=0 -set CLASSPATH=%SPARK_TOOLS_JAR%;%CLASSPATH% +set CLASSPATH=%CLASSPATH%;%SPARK_TOOLS_JAR% rem Figure out where java is. set RUNNER=java diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index f106bba678..35e23c1355 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -39,9 +39,9 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val graph = ssc.graph val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration - val pendingTimes = ssc.scheduler.jobManager.getPendingTimes() + val pendingTimes = ssc.scheduler.getPendingTimes() val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) - val sparkConf = ssc.sc.conf + val sparkConf = ssc.conf def validate() { assert(master != null, "Checkpoint.master is null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index 8005202500..ce2a9d4142 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -17,24 +17,19 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.dstream._ import StreamingContext._ -import org.apache.spark.util.MetadataCleaner - -//import Time._ - +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.scheduler.Job import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.MetadataCleaner -import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.reflect.ClassTag import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import org.apache.hadoop.fs.{FileSystem, Path} -import 
org.apache.hadoop.conf.Configuration /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index b9a58fded6..daed7ff7c3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -21,6 +21,7 @@ import dstream.InputDStream import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import org.apache.spark.Logging +import org.apache.spark.streaming.scheduler.Job final private[streaming] class DStreamGraph extends Serializable with Logging { initLogging() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala deleted file mode 100644 index 5233129506..0000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming - -import org.apache.spark.Logging -import org.apache.spark.SparkEnv -import java.util.concurrent.Executors -import collection.mutable.HashMap -import collection.mutable.ArrayBuffer - - -private[streaming] -class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { - - class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { - def run() { - SparkEnv.set(ssc.env) - try { - val timeTaken = job.run() - logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format( - (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0)) - } catch { - case e: Exception => - logError("Running " + job + " failed", e) - } - clearJob(job) - } - } - - initLogging() - - val jobExecutor = Executors.newFixedThreadPool(numThreads) - val jobs = new HashMap[Time, ArrayBuffer[Job]] - - def runJob(job: Job) { - jobs.synchronized { - jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job - } - jobExecutor.execute(new JobHandler(ssc, job)) - logInfo("Added " + job + " to queue") - } - - def stop() { - jobExecutor.shutdown() - } - - private def clearJob(job: Job) { - var timeCleared = false - val time = job.time - jobs.synchronized { - val jobsOfTime = jobs.get(time) - if (jobsOfTime.isDefined) { - jobsOfTime.get -= job - if (jobsOfTime.get.isEmpty) { - jobs -= time - timeCleared = true - } - } else { - throw new Exception("Job finished for time " + job.time + - " but time does not exist in jobs") - } - } - if (timeCleared) { - ssc.scheduler.clearOldMetadata(time) - } - } - - def getPendingTimes(): Array[Time] = { - jobs.synchronized { - jobs.keySet.toArray - } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 286ec285a9..339f6e64a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -47,9 +47,9 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import twitter4j.Status import twitter4j.auth.Authorization +import org.apache.spark.streaming.scheduler._ import akka.util.ByteString - /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic * information (such as, cluster URL and job name) to internally create a SparkContext, it provides @@ -160,9 +160,10 @@ class StreamingContext private ( } } - protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null - protected[streaming] var receiverJobThread: Thread = null - protected[streaming] var scheduler: Scheduler = null + protected[streaming] val checkpointDuration: Duration = { + if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration + } + protected[streaming] val scheduler = new JobScheduler(this) /** * Return the associated Spark context @@ -524,6 +525,13 @@ class StreamingContext private ( graph.addOutputStream(outputStream) } + /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. 
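+ * Listeners are notified asynchronously of batch start and batch completion events (StreamingListenerBatchStarted and StreamingListenerBatchCompleted).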
+ */ + def addStreamingListener(streamingListener: StreamingListener) { + scheduler.listenerBus.addListener(streamingListener) + } + protected def validate() { assert(graph != null, "Graph is null") graph.validate() @@ -539,27 +547,22 @@ class StreamingContext private ( * Start the execution of the streams. */ def start() { - if (checkpointDir != null && checkpointDuration == null && graph != null) { - checkpointDuration = graph.batchDuration - } - validate() + // Get the network input streams val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true case _ => false }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray + // Start the network input tracker (must start before receivers) if (networkInputStreams.length > 0) { - // Start the network input tracker (must start before receivers) networkInputTracker = new NetworkInputTracker(this, networkInputStreams) networkInputTracker.start() } - Thread.sleep(1000) // Start the scheduler - scheduler = new Scheduler(this) scheduler.start() } @@ -570,7 +573,6 @@ class StreamingContext private ( try { if (scheduler != null) scheduler.stop() if (networkInputTracker != null) networkInputTracker.stop() - if (receiverJobThread != null) receiverJobThread.interrupt() sc.stop() logInfo("StreamingContext stopped successfully") } catch { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 5842a7cd68..29f673d8ae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -40,6 +40,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaRDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream._ import org.apache.spark.SparkConf +import org.apache.spark.streaming.scheduler.StreamingListener /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -696,6 +697,13 @@ class JavaStreamingContext(val ssc: StreamingContext) { ssc.remember(duration) } + /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. + */ + def addStreamingListener(streamingListener: StreamingListener) { + ssc.addStreamingListener(streamingListener) + } + /** * Starts the execution of the streams. 
*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 98b14cb224..364abcde68 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -18,7 +18,8 @@ package org.apache.spark.streaming.dstream import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, DStream, Job, Time} +import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.scheduler.Job import scala.reflect.ClassTag private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index bd607f9d18..1839ca3578 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -33,6 +33,7 @@ import org.apache.spark.streaming._ import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.rdd.{RDD, BlockRDD} import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId} +import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver} /** * Abstract class for defining any InputDStream that has to start a receiver on worker diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala new file mode 100644 index 0000000000..4e8d07fe92 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming.Time + +/** + * Class having information on completed batches. + * @param batchTime Time of the batch + * @param submissionTime Clock time of when jobs of this batch was submitted to + * the streaming scheduler queue + * @param processingStartTime Clock time of when the first job of this batch started processing + * @param processingEndTime Clock time of when the last job of this batch finished processing + */ +case class BatchInfo( + batchTime: Time, + submissionTime: Long, + processingStartTime: Option[Long], + processingEndTime: Option[Long] + ) { + + /** + * Time taken for the first job of this batch to start processing from the time this batch + * was submitted to the streaming scheduler. Essentially, it is + * `processingStartTime` - `submissionTime`. 
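+ * For example, a batch submitted at clock time 1000 ms whose first job starts processing at 1250 ms has a scheduling delay of 250 ms.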
+ */ + def schedulingDelay = processingStartTime.map(_ - submissionTime) + + /** + * Time taken for the all jobs of this batch to finish processing from the time they started + * processing. Essentially, it is `processingEndTime` - `processingStartTime`. + */ + def processingDelay = processingEndTime.zip(processingStartTime).map(x => x._1 - x._2).headOption + + /** + * Time taken for all the jobs of this batch to finish processing from the time they + * were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + */ + def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 2128b7c7a6..7341bfbc99 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -15,13 +15,17 @@ * limitations under the License. */ -package org.apache.spark.streaming +package org.apache.spark.streaming.scheduler -import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.streaming.Time +/** + * Class representing a Spark computation. It may contain multiple Spark jobs. + */ private[streaming] class Job(val time: Time, func: () => _) { - val id = Job.getNewId() + var id: String = _ + def run(): Long = { val startTime = System.currentTimeMillis func() @@ -29,13 +33,9 @@ class Job(val time: Time, func: () => _) { (stopTime - startTime) } - override def toString = "streaming job " + id + " @ " + time -} - -private[streaming] -object Job { - val id = new AtomicLong(0) - - def getNewId() = id.getAndIncrement() -} + def setId(number: Int) { + id = "streaming job " + time + "." + number + } + override def toString = id +}
\ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 82ed6bed69..dbd08415a1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -15,31 +15,35 @@ * limitations under the License. */ -package org.apache.spark.streaming +package org.apache.spark.streaming.scheduler -import util.{ManualClock, RecurringTimer, Clock} import org.apache.spark.SparkEnv import org.apache.spark.Logging +import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} +import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} +/** + * This class generates jobs from DStreams as well as drives checkpointing and cleaning + * up DStream metadata. + */ private[streaming] -class Scheduler(ssc: StreamingContext) extends Logging { +class JobGenerator(jobScheduler: JobScheduler) extends Logging { initLogging() - val concurrentJobs = ssc.sc.conf.getOrElse("spark.streaming.concurrentJobs", "1").toInt - val jobManager = new JobManager(ssc, concurrentJobs) - val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(ssc.conf, ssc.checkpointDir) - } else { - null - } - + val ssc = jobScheduler.ssc val clockClass = ssc.sc.conf.getOrElse( "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, longTime => generateJobs(new Time(longTime))) val graph = ssc.graph + lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { + new CheckpointWriter(ssc.conf, ssc.checkpointDir) + } else { + null + } + var latestTime: Time = null def start() = synchronized { @@ -48,26 +52,24 @@ class Scheduler(ssc: StreamingContext) extends Logging { } else { startFirstTime() } - logInfo("Scheduler started") + logInfo("JobGenerator started") } def stop() = synchronized { timer.stop() - jobManager.stop() if (checkpointWriter != null) checkpointWriter.stop() ssc.graph.stop() - logInfo("Scheduler stopped") + logInfo("JobGenerator stopped") } private def startFirstTime() { val startTime = new Time(timer.getStartTime()) graph.start(startTime - graph.batchDuration) timer.start(startTime.milliseconds) - logInfo("Scheduler's timer started at " + startTime) + logInfo("JobGenerator's timer started at " + startTime) } private def restart() { - // If manual clock is being used for testing, then // either set the manual clock to the last checkpointed time, // or if the property is defined set it to that time @@ -93,35 +95,34 @@ class Scheduler(ssc: StreamingContext) extends Logging { val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) timesToReschedule.foreach(time => - graph.generateJobs(time).foreach(jobManager.runJob) + jobScheduler.runJobs(time, graph.generateJobs(time)) ) // Restart the timer timer.start(restartTime.milliseconds) - logInfo("Scheduler's timer restarted at " + restartTime) + logInfo("JobGenerator's timer restarted at " + restartTime) } /** Generate jobs and perform checkpoint for the given `time`. 
*/ - def generateJobs(time: Time) { + private def generateJobs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") - graph.generateJobs(time).foreach(jobManager.runJob) + jobScheduler.runJobs(time, graph.generateJobs(time)) latestTime = time doCheckpoint(time) } /** - * Clear old metadata assuming jobs of `time` have finished processing. - * And also perform checkpoint. + * On batch completion, clear old metadata and checkpoint computation. */ - def clearOldMetadata(time: Time) { + private[streaming] def onBatchCompletion(time: Time) { ssc.graph.clearOldMetadata(time) doCheckpoint(time) } /** Perform checkpoint for the give `time`. */ - def doCheckpoint(time: Time) = synchronized { - if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { + private def doCheckpoint(time: Time) = synchronized { + if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) ssc.graph.updateCheckpointData(time) checkpointWriter.write(new Checkpoint(ssc, time)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala new file mode 100644 index 0000000000..9511ccfbed --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.Logging +import org.apache.spark.SparkEnv +import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} +import scala.collection.mutable.HashSet +import org.apache.spark.streaming._ + +/** + * This class schedules jobs to be run on Spark. It uses the JobGenerator to generate + * the jobs and runs them using a thread pool. 
Number of threads + */ +private[streaming] +class JobScheduler(val ssc: StreamingContext) extends Logging { + + initLogging() + + val jobSets = new ConcurrentHashMap[Time, JobSet] + val numConcurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt + val executor = Executors.newFixedThreadPool(numConcurrentJobs) + val generator = new JobGenerator(this) + val listenerBus = new StreamingListenerBus() + + def clock = generator.clock + + def start() { + generator.start() + } + + def stop() { + generator.stop() + executor.shutdown() + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + } + + def runJobs(time: Time, jobs: Seq[Job]) { + if (jobs.isEmpty) { + logInfo("No jobs added for time " + time) + } else { + val jobSet = new JobSet(time, jobs) + jobSets.put(time, jobSet) + jobSet.jobs.foreach(job => executor.execute(new JobHandler(job))) + logInfo("Added jobs for time " + time) + } + } + + def getPendingTimes(): Array[Time] = { + jobSets.keySet.toArray(new Array[Time](0)) + } + + private def beforeJobStart(job: Job) { + val jobSet = jobSets.get(job.time) + if (!jobSet.hasStarted) { + listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo())) + } + jobSet.beforeJobStart(job) + logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) + SparkEnv.set(generator.ssc.env) + } + + private def afterJobEnd(job: Job) { + val jobSet = jobSets.get(job.time) + jobSet.afterJobStop(job) + logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) + if (jobSet.hasCompleted) { + jobSets.remove(jobSet.time) + generator.onBatchCompletion(jobSet.time) + logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( + jobSet.totalDelay / 1000.0, jobSet.time.toString, + jobSet.processingDelay / 1000.0 + )) + listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo())) + } + } + + private[streaming] + class JobHandler(job: Job) extends Runnable { + def run() { + beforeJobStart(job) + try { + job.run() + } catch { + case e: Exception => + logError("Running " + job + " failed", e) + } + afterJobEnd(job) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala new file mode 100644 index 0000000000..57268674ea --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable.HashSet +import org.apache.spark.streaming.Time + +/** Class representing a set of Jobs + * belong to the same batch. 
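+ * A JobSet records the batch's submission time and its processing start and end times, from which its scheduling and processing delays are computed.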
+ */ +private[streaming] +case class JobSet(time: Time, jobs: Seq[Job]) { + + private val incompleteJobs = new HashSet[Job]() + var submissionTime = System.currentTimeMillis() // when this jobset was submitted + var processingStartTime = -1L // when the first job of this jobset started processing + var processingEndTime = -1L // when the last job of this jobset finished processing + + jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) } + incompleteJobs ++= jobs + + def beforeJobStart(job: Job) { + if (processingStartTime < 0) processingStartTime = System.currentTimeMillis() + } + + def afterJobStop(job: Job) { + incompleteJobs -= job + if (hasCompleted) processingEndTime = System.currentTimeMillis() + } + + def hasStarted() = (processingStartTime > 0) + + def hasCompleted() = incompleteJobs.isEmpty + + // Time taken to process all the jobs from the time they started processing + // (i.e. not including the time they wait in the streaming scheduler queue) + def processingDelay = processingEndTime - processingStartTime + + // Time taken to process all the jobs from the time they were submitted + // (i.e. including the time they wait in the streaming scheduler queue) + def totalDelay = { + processingEndTime - time.milliseconds + } + + def toBatchInfo(): BatchInfo = { + new BatchInfo( + time, + submissionTime, + if (processingStartTime >= 0 ) Some(processingStartTime) else None, + if (processingEndTime >= 0 ) Some(processingEndTime) else None + ) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index 6e9a781978..abff55d77c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.streaming +package org.apache.spark.streaming.scheduler import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} @@ -31,6 +31,7 @@ import akka.actor._ import akka.pattern.ask import akka.dispatch._ import org.apache.spark.storage.BlockId +import org.apache.spark.streaming.{Time, StreamingContext} private[streaming] sealed trait NetworkInputTrackerMessage private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala new file mode 100644 index 0000000000..36225e190c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable.Queue +import org.apache.spark.util.Distribution + +/** Base trait for events related to StreamingListener */ +sealed trait StreamingListenerEvent + +case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent + +case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent + + +/** + * A listener interface for receiving information about an ongoing streaming + * computation. + */ +trait StreamingListener { + /** + * Called when processing of a batch has completed + */ + def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { } + + /** + * Called when processing of a batch has started + */ + def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { } +} + + +/** + * A simple StreamingListener that logs summary statistics across Spark Streaming batches + * @param numBatchInfos Number of last batches to consider for generating statistics (default: 10) + */ +class StatsReportListener(numBatchInfos: Int = 10) extends StreamingListener { + // Queue containing latest completed batches + val batchInfos = new Queue[BatchInfo]() + + override def onBatchCompleted(batchStarted: StreamingListenerBatchCompleted) { + batchInfos.enqueue(batchStarted.batchInfo) + if (batchInfos.size > numBatchInfos) batchInfos.dequeue() + printStats() + } + + def printStats() { + showMillisDistribution("Total delay: ", _.totalDelay) + showMillisDistribution("Processing time: ", _.processingDelay) + } + + def showMillisDistribution(heading: String, getMetric: BatchInfo => Option[Long]) { + org.apache.spark.scheduler.StatsReportListener.showMillisDistribution( + heading, extractDistribution(getMetric)) + } + + def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { + Distribution(batchInfos.flatMap(getMetric(_)).map(_.toDouble)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala new file mode 100644 index 0000000000..110a20f282 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.Logging +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.concurrent.LinkedBlockingQueue + +/** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */ +private[spark] class StreamingListenerBus() extends Logging { + private val listeners = new ArrayBuffer[StreamingListener]() with SynchronizedBuffer[StreamingListener] + + /* Cap the capacity of the StreamingListenerEvent queue so we get an explicit error (rather than + * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY) + private var queueFullErrorMessageLogged = false + + new Thread("StreamingListenerBus") { + setDaemon(true) + override def run() { + while (true) { + val event = eventQueue.take + event match { + case batchStarted: StreamingListenerBatchStarted => + listeners.foreach(_.onBatchStarted(batchStarted)) + case batchCompleted: StreamingListenerBatchCompleted => + listeners.foreach(_.onBatchCompleted(batchCompleted)) + case _ => + } + } + } + }.start() + + def addListener(listener: StreamingListener) { + listeners += listener + } + + def post(event: StreamingListenerEvent) { + val eventAdded = eventQueue.offer(event) + if (!eventAdded && !queueFullErrorMessageLogged) { + logError("Dropping StreamingListenerEvent because no remaining room in event queue. " + + "This likely means one of the StreamingListeners is too slow and cannot keep up with the " + + "rate at which batch events are being posted.") + queueFullErrorMessageLogged = true + } + } + + /** + * Waits until there are no more events in the queue, or until the specified time has elapsed. + * Used for testing only. Returns true if the queue has emptied and false if the specified time + * elapsed before the queue emptied. + */ + def waitUntilEmpty(timeoutMillis: Int): Boolean = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!eventQueue.isEmpty()) { + if (System.currentTimeMillis > finishTime) { + return false + } + /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify + * add overhead in the general case.
*/ + Thread.sleep(10) + } + return true + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 60e986cb9d..ee6b433d1f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -26,17 +26,6 @@ import util.ManualClock import org.apache.spark.{SparkContext, SparkConf} class BasicOperationsSuite extends TestSuiteBase { - - override def framework = "BasicOperationsSuite" - - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - test("map") { val input = Seq(1 to 4, 5 to 8, 9 to 12) testOperation( diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index ca230fd056..c60a3f5390 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -40,29 +40,25 @@ import org.apache.spark.streaming.util.ManualClock * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { +class CheckpointSuite extends TestSuiteBase { - before { + var ssc: StreamingContext = null + + override def batchDuration = Milliseconds(500) + + override def actuallyWait = true // to allow checkpoints to be written + + override def beforeFunction() { + super.beforeFunction() FileUtils.deleteDirectory(new File(checkpointDir)) } - after { + override def afterFunction() { + super.afterFunction() if (ssc != null) ssc.stop() FileUtils.deleteDirectory(new File(checkpointDir)) - - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") } - var ssc: StreamingContext = null - - override def framework = "CheckpointSuite" - - override def batchDuration = Milliseconds(500) - - override def actuallyWait = true - test("basic rdd checkpoints + dstream graph checkpoint recovery") { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 6337c5359c..da9b04de1a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -32,17 +32,22 @@ import collection.mutable.ArrayBuffer * This testsuite tests master failures at random times while the stream is running using * the real clock. 
*/ -class FailureSuite extends FunSuite with BeforeAndAfter with Logging { +class FailureSuite extends TestSuiteBase with Logging { var directory = "FailureSuite" val numBatches = 30 - val batchDuration = Milliseconds(1000) - before { + override def batchDuration = Milliseconds(1000) + + override def useManualClock = false + + override def beforeFunction() { + super.beforeFunction() FileUtils.deleteDirectory(new File(directory)) } - after { + override def afterFunction() { + super.afterFunction() FileUtils.deleteDirectory(new File(directory)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 8c16daa21c..52381c10b0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -50,16 +50,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val testPort = 9999 - override def checkpointDir = "checkpoint" - - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - test("socket input stream") { // Start the server val testServer = new TestServer() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala new file mode 100644 index 0000000000..fa64142096 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import org.apache.spark.streaming.scheduler._ +import scala.collection.mutable.ArrayBuffer +import org.scalatest.matchers.ShouldMatchers + +class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers { + + val input = (1 to 4).map(Seq(_)).toSeq + val operation = (d: DStream[Int]) => d.map(x => x) + + // To make sure that the processing start and end times in collected + // information are different for successive batches + override def batchDuration = Milliseconds(100) + override def actuallyWait = true + + test("basic BatchInfo generation") { + val ssc = setupStreams(input, operation) + val collector = new BatchInfoCollector + ssc.addStreamingListener(collector) + runStreams(ssc, input.size, input.size) + val batchInfos = collector.batchInfos + batchInfos should have size 4 + + batchInfos.foreach(info => { + info.schedulingDelay should not be None + info.processingDelay should not be None + info.totalDelay should not be None + info.schedulingDelay.get should be >= 0L + info.processingDelay.get should be >= 0L + info.totalDelay.get should be >= 0L + }) + + isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true) + } + + /** Check if a sequence of numbers is in increasing order */ + def isInIncreasingOrder(seq: Seq[Long]): Boolean = { + for(i <- 1 until seq.size) { + if (seq(i - 1) > seq(i)) return false + } + true + } + + /** Listener that collects information on processed batches */ + class BatchInfoCollector extends StreamingListener { + val batchInfos = new ArrayBuffer[BatchInfo] + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + batchInfos += batchCompleted.batchInfo + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 3dd6718491..33464bc3a1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -110,7 +110,7 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = "TestSuiteBase" + def framework = this.getClass.getSimpleName // Master for Spark context def master = "local[2]" @@ -127,15 +127,45 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Maximum time to wait before the test times out def maxWaitTimeMillis = 10000 + // Whether to use manual clock or not + def useManualClock = true + // Whether to actually wait in real time before changing manual clock def actuallyWait = false - // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. + //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. val conf = new SparkConf() .setMaster(master) .setAppName(framework) .set("spark.cleaner.ttl", "3600") + // Default before function for any streaming test suite. 
Override this + // if you want to add your stuff to "before" (i.e., don't call before { } ) + def beforeFunction() { + if (useManualClock) { + conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + } + } + + // Default after function for any streaming test suite. Override this + // if you want to add your stuff to "after" (i.e., don't call after { } ) + def afterFunction() { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + before(beforeFunction) + after(afterFunction) + /** * Set up required DStreams to test the DStream operation using the two sequences * of input collections. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index 3242c4cd11..c92c34d49b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -21,19 +21,9 @@ import org.apache.spark.streaming.StreamingContext._ class WindowOperationsSuite extends TestSuiteBase { - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer - override def framework = "WindowOperationsSuite" - - override def maxWaitTimeMillis = 20000 - - override def batchDuration = Seconds(1) - - after { - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } + override def batchDuration = Seconds(1) // making sure it's visible in this class val largerSlideInput = Seq( Seq(("a", 1)), diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index cc150888eb..595a7ee8c3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -422,8 +422,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } def monitorApplication(appId: ApplicationId): Boolean = { + val interval = new SparkConf().getOrElse("spark.yarn.report.interval", "1000").toLong + while (true) { - Thread.sleep(1000) + Thread.sleep(interval) val report = super.getApplicationReport(appId) logInfo("Application report from ASM: \n" + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 4c9fee5695..5966a0f757 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -27,8 +27,8 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import org.apache.spark.Logging -import org.apache.spark.scheduler.SplitInfo -import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} +import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl} +import 
org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils import org.apache.hadoop.conf.Configuration @@ -214,9 +214,9 @@ private[yarn] class YarnAllocationHandler( // host if there are sufficiently large number of hosts/containers. val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers) + allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) + allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) + allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) // Run each of the allocated containers for (container <- allocatedContainers) { diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 63a0449e5a..522e0a9ad7 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -20,13 +20,14 @@ package org.apache.spark.scheduler.cluster import org.apache.spark._ import org.apache.hadoop.conf.Configuration import org.apache.spark.deploy.yarn.YarnAllocationHandler +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils /** * * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM. */ -private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { +private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { def this(sc: SparkContext) = this(sc, new Configuration()) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 6feaaff014..4b69f5078b 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,9 +20,10 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 29b3f22e13..2d9fbcb400 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler.cluster +import org.apache.hadoop.conf.Configuration + import org.apache.spark._ import 
org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.util.Utils -import org.apache.hadoop.conf.Configuration /** * - * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done + * This is a simple extension to TaskSchedulerImpl - to ensure that appropriate initialization of + * ApplicationMaster, etc. is done */ -private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) + extends TaskSchedulerImpl(sc) { logInfo("Created YarnClusterScheduler")
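The listener API added in this patch can be wired up directly from an application; the following is a minimal sketch, not part of the patch. The class BatchDelayLogger, the object name, and the input directory are invented for illustration, and it assumes the pre-existing StreamingContext(master, appName, batchDuration) constructor and textFileStream source. The listener is registered with addStreamingListener, so batch-completion events posted through StreamingListenerBus reach it, and it prints the BatchInfo delay fields that StreamingListenerSuite asserts on.

import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

// Hypothetical listener that prints the per-batch delays carried in BatchInfo.
class BatchDelayLogger extends StreamingListener {
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
    val info = batchCompleted.batchInfo
    // The delay fields are Options; they are defined once the batch has been scheduled and processed.
    println("batch submitted at " + info.submissionTime +
      ": scheduling delay = " + info.schedulingDelay.getOrElse(-1L) + " ms" +
      ", processing delay = " + info.processingDelay.getOrElse(-1L) + " ms" +
      ", total delay = " + info.totalDelay.getOrElse(-1L) + " ms")
  }
}

object StreamingListenerExample {
  def main(args: Array[String]) {
    // Master, app name, batch interval and input directory are placeholders.
    val ssc = new StreamingContext("local[2]", "StreamingListenerExample", Seconds(1))
    ssc.addStreamingListener(new BatchDelayLogger)  // delivered asynchronously via StreamingListenerBus
    ssc.textFileStream("/tmp/streaming-input").count().print()
    ssc.start()
  }
}

Because StreamingListenerBus drains its bounded queue from a daemon thread, a slow listener in this sketch would only cause further events to be dropped (with the logged error) rather than blocking batch scheduling.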