author     Matei Zaharia <matei@databricks.com>    2013-12-29 15:08:08 -0500
committer  Matei Zaharia <matei@databricks.com>    2013-12-29 15:08:08 -0500
commit     b4ceed40d6e511a1d475b3f4fbcdd2ad24c02b5a (patch)
tree       486226041e35962c1543902d8ffc10a81f4223a5
parent     58c6fa2041b99160f254b17c2b71de9d82c53f8c (diff)
parent     ad3dfd153196497fefe6c1925e0a495a4373f1c5 (diff)
Merge remote-tracking branch 'origin/master' into conf2
Conflicts:
    core/src/main/scala/org/apache/spark/SparkContext.scala
    core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
    core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
    core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
    core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
    core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
    core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
    core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
    new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
    streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
    streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
    streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
    streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
    streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
    streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
    streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 45
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 42
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala) | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Pool.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala) | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala) | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala) | 47
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 700
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala) | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala | 714
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala | 108
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala | 224
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala | 191
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala | 31
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala | 27
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala | 90
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 31
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala | 15
-rw-r--r--  core/src/test/scala/org/apache/spark/FailureSuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala | 58
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala (renamed from core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala) | 16
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala (renamed from core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala) | 3
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala | 19
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala (renamed from core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala) | 30
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala (renamed from core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala) | 30
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala | 227
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala | 73
-rw-r--r--  examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 232
-rw-r--r--  new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 4
-rw-r--r--  new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 10
-rw-r--r--  new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala | 3
-rw-r--r--  new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala | 3
-rw-r--r--  new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala | 3
-rw-r--r--  project/SparkBuild.scala | 5
-rw-r--r--  python/pyspark/java_gateway.py | 1
-rw-r--r--  python/pyspark/mllib/__init__.py | 20
-rw-r--r--  python/pyspark/mllib/_common.py | 227
-rw-r--r--  python/pyspark/mllib/classification.py | 86
-rw-r--r--  python/pyspark/mllib/clustering.py | 79
-rw-r--r--  python/pyspark/mllib/recommendation.py | 74
-rw-r--r--  python/pyspark/mllib/regression.py | 110
-rw-r--r--  python/pyspark/serializers.py | 2
-rw-r--r--  python/pyspark/shell.py | 2
-rwxr-xr-x  spark-class | 6
-rw-r--r--  spark-class2.cmd | 4
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 4
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/DStream.scala | 11
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala | 1
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala | 88
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala | 26
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala | 8
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala | 3
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala | 1
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala | 55
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/Job.scala) | 24
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala) | 51
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala | 108
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala | 68
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala) | 3
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala | 75
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala | 81
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala | 11
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala | 26
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala | 13
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala | 10
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala | 71
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala | 34
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala | 14
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 4
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 10
-rw-r--r--  yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala | 3
-rw-r--r--  yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala | 3
-rw-r--r--  yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala | 10
90 files changed, 2721 insertions(+), 1808 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c109ff930c..6f54fa7a5a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -43,11 +43,10 @@ import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend,
-SimrSchedulerBackend, SparkDeploySchedulerBackend}
-import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend,
-MesosSchedulerBackend}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+ SparkDeploySchedulerBackend, SimrSchedulerBackend}
+import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util._
@@ -560,9 +559,7 @@ class SparkContext(
}
addedFiles(key) = System.currentTimeMillis
- // Fetch the file locally in case a job is executed locally.
- // Jobs that run through LocalScheduler will already fetch the required dependencies,
- // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
+ // Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
@@ -1070,18 +1067,30 @@ object SparkContext {
// Regular expression for connection to Simr cluster
val SIMR_REGEX = """simr://(.*)""".r
+ // When running locally, don't try to re-execute tasks on failure.
+ val MAX_LOCAL_TASK_FAILURES = 1
+
master match {
case "local" =>
- new LocalScheduler(1, 0, sc)
+ val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
+ val backend = new LocalBackend(scheduler, 1)
+ scheduler.initialize(backend)
+ scheduler
case LOCAL_N_REGEX(threads) =>
- new LocalScheduler(threads.toInt, 0, sc)
+ val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
+ val backend = new LocalBackend(scheduler, threads.toInt)
+ scheduler.initialize(backend)
+ scheduler
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
- new LocalScheduler(threads.toInt, maxFailures.toInt, sc)
+ val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
+ val backend = new LocalBackend(scheduler, threads.toInt)
+ scheduler.initialize(backend)
+ scheduler
case SPARK_REGEX(sparkUrl) =>
- val scheduler = new ClusterScheduler(sc)
+ val scheduler = new TaskSchedulerImpl(sc)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
scheduler.initialize(backend)
@@ -1096,7 +1105,7 @@ object SparkContext {
memoryPerSlaveInt, sc.executorMemory))
}
- val scheduler = new ClusterScheduler(sc)
+ val scheduler = new TaskSchedulerImpl(sc)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val masterUrls = localCluster.start()
@@ -1111,7 +1120,7 @@ object SparkContext {
val scheduler = try {
val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
- cons.newInstance(sc).asInstanceOf[ClusterScheduler]
+ cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
} catch {
// TODO: Enumerate the exact reasons why it can fail
// But irrespective of it, it means we cannot proceed !
@@ -1127,7 +1136,7 @@ object SparkContext {
val scheduler = try {
val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
- cons.newInstance(sc).asInstanceOf[ClusterScheduler]
+ cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
} catch {
case th: Throwable => {
@@ -1137,7 +1146,7 @@ object SparkContext {
val backend = try {
val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
- val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext])
+ val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
case th: Throwable => {
@@ -1150,7 +1159,7 @@ object SparkContext {
case mesosUrl @ MESOS_REGEX(_) =>
MesosNativeLibrary.load()
- val scheduler = new ClusterScheduler(sc)
+ val scheduler = new TaskSchedulerImpl(sc)
val coarseGrained = sc.conf.getOrElse("spark.mesos.coarse", "false").toBoolean
val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
val backend = if (coarseGrained) {
@@ -1162,7 +1171,7 @@ object SparkContext {
scheduler
case SIMR_REGEX(simrUrl) =>
- val scheduler = new ClusterScheduler(sc)
+ val scheduler = new TaskSchedulerImpl(sc)
val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl)
scheduler.initialize(backend)
scheduler
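
Every arm of the master match above now follows the same wiring pattern: build a TaskSchedulerImpl, build a SchedulerBackend for the deployment mode, and connect the two with initialize(). A condensed sketch of that pattern for the local arms (the helper itself is hypothetical; the classes and calls are the ones in this diff):

    import org.apache.spark.SparkContext
    import org.apache.spark.scheduler.TaskSchedulerImpl
    import org.apache.spark.scheduler.local.LocalBackend

    // Hypothetical helper mirroring the "local" / "local[N]" arms above.
    def makeLocalScheduler(sc: SparkContext, threads: Int, maxFailures: Int): TaskSchedulerImpl = {
      val scheduler = new TaskSchedulerImpl(sc, maxFailures, isLocal = true)
      val backend = new LocalBackend(scheduler, threads) // runs tasks in-process on `threads` threads
      scheduler.initialize(backend)                      // scheduler delegates task launching to the backend
      scheduler
    }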
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index c1e5e04b31..faf6dcd618 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -53,5 +53,3 @@ private[spark] case class ExceptionFailure(
private[spark] case object TaskResultLost extends TaskEndReason
private[spark] case object TaskKilled extends TaskEndReason
-
-private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index ec47ba1b56..a801d85770 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -140,12 +140,12 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
<body>
{linkToMaster}
<div>
- <div style="float:left;width:40%">{backButton}</div>
+ <div style="float:left; margin-right:10px">{backButton}</div>
<div style="float:left;">{range}</div>
- <div style="float:right;">{nextButton}</div>
+ <div style="float:right; margin-left:10px">{nextButton}</div>
</div>
<br />
- <div style="height:500px;overflow:auto;padding:5px;">
+ <div style="height:500px; overflow:auto; padding:5px;">
<pre>{logText}</pre>
</div>
</body>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 77aa24e6b6..e06e49d9d2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -152,7 +152,8 @@ class DAGScheduler(
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
- val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+ // Missing tasks from each stage
+ val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]]
var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
val activeJobs = new HashSet[ActiveJob]
@@ -240,7 +241,8 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId)
+ val stage =
+ newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@@ -249,7 +251,8 @@ class DAGScheduler(
/**
* Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation
* of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided
- * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly.
+ * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage
+ * directly.
*/
private def newStage(
rdd: RDD[_],
@@ -359,7 +362,8 @@ class DAGScheduler(
stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
val parents = getParentStages(s.rdd, jobId)
- val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
+ val parentsWithoutThisJobId = parents.filter(p =>
+ !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
}
}
@@ -367,8 +371,9 @@ class DAGScheduler(
}
/**
- * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that
- * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation.
+ * Removes job and any stages that are not needed by any other job. Returns the set of ids for
+ * stages that were removed. The associated tasks for those stages need to be cancelled if we
+ * got here via job cancellation.
*/
private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
val registeredStages = jobIdToStageIds(jobId)
@@ -379,7 +384,8 @@ class DAGScheduler(
stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
case (stageId, jobSet) =>
if (!jobSet.contains(jobId)) {
- logError("Job %d not registered for stage %d even though that stage was registered for the job"
+ logError(
+ "Job %d not registered for stage %d even though that stage was registered for the job"
.format(jobId, stageId))
} else {
def removeStage(stageId: Int) {
@@ -390,7 +396,8 @@ class DAGScheduler(
running -= s
}
stageToInfos -= s
- shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove)
+ shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleId =>
+ shuffleToMapStage.remove(shuffleId))
if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
logDebug("Removing pending status for stage %d".format(stageId))
}
@@ -408,7 +415,8 @@ class DAGScheduler(
stageIdToStage -= stageId
stageIdToJobIds -= stageId
- logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size))
+ logDebug("After removal of stage %d, remaining stages = %d"
+ .format(stageId, stageIdToStage.size))
}
jobSet -= jobId
@@ -460,7 +468,8 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
- eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+ eventProcessActor ! JobSubmitted(
+ jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
waiter
}
@@ -495,7 +504,8 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
- eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
+ eventProcessActor ! JobSubmitted(
+ jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -530,8 +540,8 @@ class DAGScheduler(
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
var finalStage: Stage = null
try {
- // New stage creation at times and if its not protected, the scheduler thread is killed.
- // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted
+ // New stage creation may throw an exception if, for example, jobs are run on a HadoopRDD
+ // whose underlying HDFS files have been deleted.
finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
} catch {
case e: Exception =>
@@ -564,7 +574,8 @@ class DAGScheduler(
case JobGroupCancelled(groupId) =>
// Cancel all jobs belonging to this job group.
// First finds all active jobs with this group id, and then kill stages for them.
- val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ val activeInGroup = activeJobs.filter(activeJob =>
+ groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
val jobIds = activeInGroup.map(_.jobId)
jobIds.foreach { handleJobCancellation }
@@ -586,7 +597,8 @@ class DAGScheduler(
stage <- stageIdToStage.get(task.stageId);
stageInfo <- stageToInfos.get(stage)
) {
- if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) {
+ if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 &&
+ !stageInfo.emittedTaskSizeWarning) {
stageInfo.emittedTaskSizeWarning = true
logWarning(("Stage %d (%s) contains a task of very large " +
"size (%d KB). The maximum recommended task size is %d KB.").format(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
index 5077b2b48b..2bc43a9186 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.executor.ExecutorExitCode
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 60927831a1..be5c95e59e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -328,10 +328,6 @@ class JobLogger(val user: String, val logDirName: String)
task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
case _ =>
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 596f9adde9..1791242215 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -117,8 +117,4 @@ private[spark] class Pool(
parent.decreaseRunningTasks(taskNum)
}
}
-
- override def hasPendingTasks(): Boolean = {
- schedulableQueue.exists(_.hasPendingTasks())
- }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
index 1c7ea2dccc..d573e125a3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
@@ -42,5 +42,4 @@ private[spark] trait Schedulable {
def executorLost(executorId: String, host: String): Unit
def checkSpeculatableTasks(): Boolean
def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager]
- def hasPendingTasks(): Boolean
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index 65d3fc8187..02bdbba825 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -15,12 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.SparkContext
/**
- * A backend interface for cluster scheduling systems that allows plugging in different ones under
+ * A backend interface for scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
* machines become available and can launch tasks on them.
*/
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 3841b5616d..ee63b3c4a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -63,7 +63,7 @@ trait SparkListener {
* Called when a task begins remotely fetching its result (will not be called for tasks that do
* not need to fetch the result remotely).
*/
- def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
/**
* Called when a task ends
@@ -131,8 +131,8 @@ object StatsReportListener extends Logging {
def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
val stats = d.statCounter
- logInfo(heading + stats)
val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+ logInfo(heading + stats)
logInfo(percentilesHeader)
logInfo("\t" + quantiles.mkString("\t"))
}
@@ -173,8 +173,6 @@ object StatsReportListener extends Logging {
showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
}
-
-
val seconds = 1000L
val minutes = seconds * 60
val hours = minutes * 60
@@ -198,7 +196,6 @@ object StatsReportListener extends Logging {
}
-
case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
object RuntimePercentage {
def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index d5824e7954..85687ea330 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging {
return true
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 319c91b933..29b0247f8a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -15,21 +15,20 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.Utils
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
-private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
extends Logging {
private val THREADS = sparkEnv.conf.getOrElse("spark.resultGetter.threads", "4").toInt
@@ -43,7 +42,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
}
def enqueueSuccessfulTask(
- taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
getTaskResultExecutor.execute(new Runnable {
override def run() {
try {
@@ -79,7 +78,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
})
}
- def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState,
+ def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
var reason: Option[TaskEndReason] = None
getTaskResultExecutor.execute(new Runnable {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 10e0478108..17b6d97e90 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -20,11 +20,12 @@ package org.apache.spark.scheduler
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
- * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
- * Each TaskScheduler schedulers task for a single SparkContext.
- * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
- * and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler.
+ * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler.
+ * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks
+ * for a single SparkContext. These schedulers get sets of tasks submitted to them from the
+ * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
+ * them, retrying if there are failures, and mitigating stragglers. They return events to the
+ * DAGScheduler.
*/
private[spark] trait TaskScheduler {
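
The reworded contract is what lets the rest of this patch treat schedulers as pluggable: SparkContext (above) loads the YARN variants reflectively and casts them to TaskSchedulerImpl, so an implementation only has to satisfy this interface. A condensed restatement of that reflective wiring, copied from the SparkContext hunk:

    // Assumes the YARN scheduler class is on the classpath (it lives in a separate module).
    val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
    val cons = clazz.getConstructor(classOf[SparkContext])
    val scheduler = cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]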
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 2707740d44..56a038dc69 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
@@ -28,37 +28,40 @@ import scala.concurrent.duration._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
- * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
- * initialize() and start(), then submit task sets through the runTasks method.
- *
- * This class can work with multiple types of clusters by acting through a SchedulerBackend.
+ * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
+ * It can also work with a local setup by using a LocalBackend and setting isLocal to true.
* It handles common logic, like determining a scheduling order across jobs, waking up to launch
* speculative tasks, etc.
*
+ * Clients should first call initialize() and start(), then submit task sets through the
+ * submitTasks method.
+ *
* THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
 * SchedulerBackends synchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
-private[spark] class ClusterScheduler(val sc: SparkContext)
- extends TaskScheduler
- with Logging
+private[spark] class TaskSchedulerImpl(
+ val sc: SparkContext,
+ val maxTaskFailures: Int = System.getProperty("spark.task.maxFailures", "4").toInt,
+ isLocal: Boolean = false)
+ extends TaskScheduler with Logging
{
val conf = sc.conf
+
// How often to check for speculative tasks
val SPECULATION_INTERVAL = conf.getOrElse("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = conf.getOrElse("spark.starvation.timeout", "15000").toLong
- // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
+ // TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
- val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
+ val activeTaskSets = new HashMap[String, TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
@@ -120,7 +123,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
override def start() {
backend.start()
- if (conf.getOrElse("spark.speculation", "false").toBoolean) {
+ if (!isLocal && conf.getOrElse("spark.speculation", "false").toBoolean) {
logInfo("Starting speculative execution thread")
import sc.env.actorSystem.dispatcher
sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
@@ -134,12 +137,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new ClusterTaskSetManager(this, taskSet)
+ val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
- if (!hasReceivedTask) {
+ if (!isLocal && !hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
@@ -293,19 +296,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+ def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
def handleSuccessfulTask(
- taskSetManager: ClusterTaskSetManager,
+ taskSetManager: TaskSetManager,
tid: Long,
taskResult: DirectTaskResult[_]) = synchronized {
taskSetManager.handleSuccessfulTask(tid, taskResult)
}
def handleFailedTask(
- taskSetManager: ClusterTaskSetManager,
+ taskSetManager: TaskSetManager,
tid: Long,
taskState: TaskState,
reason: Option[TaskEndReason]) = synchronized {
@@ -353,7 +356,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
override def defaultParallelism() = backend.defaultParallelism()
-
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
@@ -365,13 +367,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- // Check for pending tasks in all our active jobs.
- def hasPendingTasks: Boolean = {
- synchronized {
- rootPool.hasPendingTasks()
- }
- }
-
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
@@ -430,7 +425,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
-object ClusterScheduler {
+private[spark] object TaskSchedulerImpl {
/**
* Used to balance containers across hosts.
*
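
Two behavioural changes ride along with the rename above: the retry limit is now a constructor parameter (defaulting to spark.task.maxFailures, i.e. 4), and isLocal gates both the speculation thread and the starvation-warning timer. A sketch of the two call shapes, with values taken from this patch:

    // Cluster mode: retries each task up to spark.task.maxFailures times (default 4).
    val clusterScheduler = new TaskSchedulerImpl(sc)

    // Local mode: MAX_LOCAL_TASK_FAILURES = 1 means a failed task is never re-executed,
    // and isLocal = true suppresses speculation and the starvation warning.
    val localScheduler = new TaskSchedulerImpl(sc, maxTaskFailures = 1, isLocal = true)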
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 90f6bcefac..9b95e418d8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -17,32 +17,702 @@
package org.apache.spark.scheduler
-import java.nio.ByteBuffer
+import java.io.NotSerializableException
+import java.util.Arrays
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
+ Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock}
+
/**
- * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
- * each task and is responsible for retries on failure and locality. The main interfaces to it
- * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and
- * statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
+ * each task, retries tasks if they fail (up to a limited number of times), and
+ * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
+ * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
+ * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ *
+ * THREADING: This class is designed to only be called from code with a lock on the
+ * TaskScheduler (e.g. its event handlers). It should not be called from other threads.
*
- * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler
- * (e.g. its event handlers). It should not be called from other threads.
+ * @param sched the ClusterScheduler associated with the TaskSetManager
+ * @param taskSet the TaskSet to manage scheduling for
+ * @param maxTaskFailures if any particular task fails more than this number of times, the entire
+ * task set will be aborted
*/
-private[spark] trait TaskSetManager extends Schedulable {
- def schedulableQueue = null
-
- def schedulingMode = SchedulingMode.NONE
-
- def taskSet: TaskSet
+private[spark] class TaskSetManager(
+ sched: TaskSchedulerImpl,
+ val taskSet: TaskSet,
+ val maxTaskFailures: Int,
+ clock: Clock = SystemClock)
+ extends Schedulable with Logging
+{
+ // CPUs to request per task
+ val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
+ // Serializer for closures and tasks.
+ val env = SparkEnv.get
+ val ser = env.closureSerializer.newInstance()
+
+ val tasks = taskSet.tasks
+ val numTasks = tasks.length
+ val copiesRunning = new Array[Int](numTasks)
+ val successful = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+ var tasksSuccessful = 0
+
+ var weight = 1
+ var minShare = 0
+ var priority = taskSet.priority
+ var stageId = taskSet.stageId
+ var name = "TaskSet_"+taskSet.stageId.toString
+ var parent: Pool = null
+
+ var runningTasks = 0
+ private val runningTasksSet = new HashSet[Long]
+
+ // Set of pending tasks for each executor. These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+ // tasks that repeatedly fail because whenever a task failed, it is put
+ // back at the head of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
+ // but at host level.
+ private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set of pending tasks for each rack -- similar to the above.
+ private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set containing pending tasks with no locality preferences.
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // Set containing all pending tasks (also used as a stack, as above).
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[Long, TaskInfo]
+
+ // Did the TaskSet fail?
+ var failed = false
+ var causeOfFailure = ""
+
+ // How frequently to reprint duplicate exceptions in full, in milliseconds
+ val EXCEPTION_PRINT_INTERVAL =
+ System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+
+ // Map of recent exceptions (identified by string representation and top stack frame) to
+ // duplicate count (how many times the same exception has appeared) and time the full exception
+ // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
+ val recentExceptions = HashMap[String, (Int, Long)]()
+
+ // Figure out the current map output tracker epoch and set it on all tasks
+ val epoch = sched.mapOutputTracker.getEpoch
+ logDebug("Epoch for " + taskSet + ": " + epoch)
+ for (t <- tasks) {
+ t.epoch = epoch
+ }
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+ // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
+ val myLocalityLevels = computeValidLocalityLevels()
+ val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+
+ // Delay scheduling variables: we keep track of our current locality level and the time we
+ // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
+ // We then move down if we manage to launch a "more local" task.
+ var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
+ var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
+
+ override def schedulableQueue = null
+
+ override def schedulingMode = SchedulingMode.NONE
+
+ /**
+ * Add a task to all the pending-task lists that it should be on. If readding is set, we are
+ * re-adding the task so only include it in each list if it's not already there.
+ */
+ private def addPendingTask(index: Int, readding: Boolean = false) {
+ // Utility method that adds `index` to a list only if readding=false or it's not already there
+ def addTo(list: ArrayBuffer[Int]) {
+ if (!readding || !list.contains(index)) {
+ list += index
+ }
+ }
+
+ var hadAliveLocations = false
+ for (loc <- tasks(index).preferredLocations) {
+ for (execId <- loc.executorId) {
+ if (sched.isExecutorAlive(execId)) {
+ addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+ hadAliveLocations = true
+ }
+ }
+ if (sched.hasExecutorsAliveOnHost(loc.host)) {
+ addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+ for (rack <- sched.getRackForHost(loc.host)) {
+ addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
+ }
+ hadAliveLocations = true
+ }
+ }
+
+ if (!hadAliveLocations) {
+ // Even though the task might've had preferred locations, all of those hosts or executors
+ // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
+ addTo(pendingTasksWithNoPrefs)
+ }
+
+ if (!readding) {
+ allPendingTasks += index // No point scanning this whole list to find the old task there
+ }
+ }
+
+ /**
+ * Return the pending tasks list for a given executor ID, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
+ pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
+ }
+
+ /**
+ * Return the pending tasks list for a given host, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ pendingTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ /**
+ * Return the pending rack-local task list for a given rack, or an empty list if
+ * there is no map entry for that rack
+ */
+ private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
+ pendingTasksForRack.getOrElse(rack, ArrayBuffer())
+ }
+
+ /**
+ * Dequeue a pending task from the given list and return its index.
+ * Return None if the list is empty.
+ * This method also cleans up any tasks in the list that have already
+ * been launched, since we want that to happen lazily.
+ */
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ while (!list.isEmpty) {
+ val index = list.last
+ list.trimEnd(1)
+ if (copiesRunning(index) == 0 && !successful(index)) {
+ return Some(index)
+ }
+ }
+ return None
+ }
+
+ /** Check whether a task is currently running an attempt on a given host */
+ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
+ taskAttempts(taskIndex).exists(_.host == host)
+ }
+
+ /**
+ * Return a speculative task for a given executor if any are available. The task should not have
+ * an attempt running on this host, in case the host is slow. In addition, the task should meet
+ * the given locality constraint.
+ */
+ private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
+
+ if (!speculatableTasks.isEmpty) {
+ // Check for process-local or preference-less tasks; note that tasks can be process-local
+ // on multiple nodes when we replicate cached blocks, as in Spark Streaming
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val prefs = tasks(index).preferredLocations
+ val executors = prefs.flatMap(_.executorId)
+ if (prefs.size == 0 || executors.contains(execId)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+ }
+
+ // Check for node-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val locations = tasks(index).preferredLocations.map(_.host)
+ if (locations.contains(host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
+ }
+ // Check for rack-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ for (rack <- sched.getRackForHost(host)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
+ if (racks.contains(rack)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
+ }
+ }
+
+ // Check for non-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.ANY))
+ }
+ }
+ }
+
+ return None
+ }
+
+ /**
+ * Dequeue a pending task for a given node and return its index and locality level.
+ * Only search for tasks matching the given locality constraint.
+ */
+ private def findTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- findTaskFromList(getPendingTasksForHost(host))) {
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ for {
+ rack <- sched.getRackForHost(host)
+ index <- findTaskFromList(getPendingTasksForRack(rack))
+ } {
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
+
+ // Look for no-pref tasks after rack-local tasks since they can run anywhere.
+ for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ for (index <- findTaskFromList(allPendingTasks)) {
+ return Some((index, TaskLocality.ANY))
+ }
+ }
+
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(execId, host, locality)
+ }
+
+ /**
+ * Respond to an offer of a single executor from the scheduler by finding a task
+ */
def resourceOffer(
execId: String,
host: String,
availableCpus: Int,
maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription]
+ : Option[TaskDescription] =
+ {
+ if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
+ val curTime = clock.getTime()
+
+ var allowedLocality = getAllowedLocalityLevel(curTime)
+ if (allowedLocality > maxLocality) {
+ allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
+ }
+
+ findTask(execId, host, allowedLocality) match {
+ case Some((index, taskLocality)) => {
+ // Found a task; do some bookkeeping and return a task description
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out whether this should count as a preferred launch
+ logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, host, taskLocality))
+ // Do various bookkeeping
+ copiesRunning(index) += 1
+ val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
+ taskInfos(taskId) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ // Update our locality level for delay scheduling
+ currentLocalityIndex = getLocalityIndex(taskLocality)
+ lastLaunchTime = curTime
+ // Serialize and return the task
+ val startTime = clock.getTime()
+ // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+ // we assume the task can be serialized without exceptions.
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val timeTaken = clock.getTime() - startTime
+ addRunningTask(taskId)
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ if (taskAttempts(index).size == 1)
+ taskStarted(task, info)
+ return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
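+ // Illustration (values not from this patch): if delay scheduling currently allows
+ // RACK_LOCAL but the caller passes maxLocality = NODE_LOCAL, allowedLocality is clamped
+ // to NODE_LOCAL, so only process-local, node-local, or preference-free tasks can be
+ // returned; otherwise the offer yields None and the caller may retry at a wider level.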
+
+ /**
+ * Get the level we can launch tasks according to delay scheduling, based on current wait time.
+ */
+ private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
+ while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
+ currentLocalityIndex < myLocalityLevels.length - 1)
+ {
+ // Jump to the next locality level, and remove our waiting time for the current one since
+ // we don't want to count it again on the next one
+ lastLaunchTime += localityWaits(currentLocalityIndex)
+ currentLocalityIndex += 1
+ }
+ myLocalityLevels(currentLocalityIndex)
+ }
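+ // Worked example (illustrative numbers): with myLocalityLevels = [PROCESS_LOCAL,
+ // NODE_LOCAL, ANY], localityWaits = [3000, 3000, 0], and a last launch 7000 ms ago,
+ // the loop advances twice: 7000 >= 3000 drops us to NODE_LOCAL and charges 3000 ms,
+ // then 4000 >= 3000 drops us to ANY. The 1000 ms remainder stays folded into
+ // lastLaunchTime so the same wait is never counted twice.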
+
+ /**
+ * Find the index in myLocalityLevels for a given locality. This is also designed to work with
+ * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
+ * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
+ */
+ def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
+ var index = 0
+ while (locality > myLocalityLevels(index)) {
+ index += 1
+ }
+ index
+ }
+
+ private def taskStarted(task: Task[_], info: TaskInfo) {
+ sched.dagScheduler.taskStarted(task, info)
+ }
+
+ def handleTaskGettingResult(tid: Long) = {
+ val info = taskInfos(tid)
+ info.markGettingResult()
+ sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+ }
+
+ /**
+ * Marks the task as successful and notifies the DAGScheduler that a task has ended.
+ */
+ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
+ val info = taskInfos(tid)
+ val index = info.index
+ info.markSuccessful()
+ removeRunningTask(tid)
+ if (!successful(index)) {
+ logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
+ tid, info.duration, info.host, tasksSuccessful, numTasks))
+ sched.dagScheduler.taskEnded(
+ tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+
+ // Mark successful and stop if all the tasks have succeeded.
+ tasksSuccessful += 1
+ successful(index) = true
+ if (tasksSuccessful == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ } else {
+ logInfo("Ignorning task-finished event for TID " + tid + " because task " +
+ index + " has already completed successfully")
+ }
+ }
+
+ /**
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+ * DAG Scheduler.
+ */
+ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ return
+ }
+ removeRunningTask(tid)
+ val index = info.index
+ info.markFailed()
+ var failureReason = "unknown"
+ if (!successful(index)) {
+ logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
+ // Check if the problem is a map output fetch failure. In that case, this
+ // task will never succeed on any node, so tell the scheduler about it.
+ reason.foreach {
+ case fetchFailed: FetchFailed =>
+ logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ successful(index) = true
+ tasksSuccessful += 1
+ sched.taskSetFinished(this)
+ removeAllRunningTasks()
+ return
+
+ case TaskKilled =>
+ logWarning("Task %d was killed.".format(tid))
+ sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+ return
+
+ case ef: ExceptionFailure =>
+ sched.dagScheduler.taskEnded(
+ tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ if (ef.className == classOf[NotSerializableException].getName()) {
+ // If the task result wasn't serializable, there's no point in trying to re-execute it.
+ logError("Task %s:%s had a not serializable result: %s; not retrying".format(
+ taskSet.id, index, ef.description))
+ abort("Task %s:%s had a not serializable result: %s".format(
+ taskSet.id, index, ef.description))
+ return
+ }
+ val key = ef.description
+ failureReason = "Exception failure: %s".format(ef.description)
+ val now = clock.getTime()
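+ // Print the full stack trace the first time an exception appears and again once
+ // EXCEPTION_PRINT_INTERVAL ms have passed; in between, log a one-line duplicate count.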
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
+ }
+ } else {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ }
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logWarning("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
+
+ case TaskResultLost =>
+ failureReason = "Lost result for TID %s on host %s".format(tid, info.host)
+ logWarning(failureReason)
+ sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+
+ case _ => {}
+ }
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+ addPendingTask(index)
+ if (state != TaskState.KILLED) {
+ numFailures(index) += 1
+ if (numFailures(index) >= maxTaskFailures) {
+ logError("Task %s:%d failed %d times; aborting job".format(
+ taskSet.id, index, maxTaskFailures))
+ abort("Task %s:%d failed %d times (most recent failure: %s)".format(
+ taskSet.id, index, maxTaskFailures, failureReason))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def error(message: String) {
+ // Save the error message
+ abort("Error: " + message)
+ }
+
+ def abort(message: String) {
+ failed = true
+ causeOfFailure = message
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ sched.dagScheduler.taskSetFailed(taskSet, message)
+ removeAllRunningTasks()
+ sched.taskSetFinished(this)
+ }
+
+ /** If the given task ID is not in the set of running tasks, adds it.
+ *
+ * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+ */
+ def addRunningTask(tid: Long) {
+ if (runningTasksSet.add(tid) && parent != null) {
+ parent.increaseRunningTasks(1)
+ }
+ runningTasks = runningTasksSet.size
+ }
+
+ /** If the given task ID is in the set of running tasks, removes it. */
+ def removeRunningTask(tid: Long) {
+ if (runningTasksSet.remove(tid) && parent != null) {
+ parent.decreaseRunningTasks(1)
+ }
+ runningTasks = runningTasksSet.size
+ }
+
+ private[scheduler] def removeAllRunningTasks() {
+ val numRunningTasks = runningTasksSet.size
+ runningTasksSet.clear()
+ if (parent != null) {
+ parent.decreaseRunningTasks(numRunningTasks)
+ }
+ runningTasks = 0
+ }
+
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {}
+
+ override def removeSchedulable(schedulable: Schedulable) {}
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ // A TaskSetManager is a leaf of the scheduling tree, so the sorted queue is just itself.
+ ArrayBuffer[TaskSetManager](this)
+ }
+
+ /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
+ override def executorLost(execId: String, host: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+ // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
+ // task that used to have locations on only this host might now go to the no-prefs list. Note
+ // that it's okay if we add a task to the same queue twice (if it had multiple preferred
+ // locations), because findTaskFromList will skip already-running tasks.
+ for (index <- getPendingTasksForExecutor(execId)) {
+ addPendingTask(index, readding=true)
+ }
+ for (index <- getPendingTasksForHost(host)) {
+ addPendingTask(index, readding=true)
+ }
+
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
+ val index = taskInfos(tid).index
+ if (successful(index)) {
+ successful(index) = false
+ copiesRunning(index) -= 1
+ tasksSuccessful -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ }
+ }
+ }
+ // Also re-enqueue any tasks that were running on the node
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+ handleFailedTask(tid, TaskState.KILLED, None)
+ }
+ }
+
+ /**
+ * Check for tasks to be speculated and return true if there are any. This is called periodically
+ * by the TaskScheduler.
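+ *
+ * For example, with SPECULATION_QUANTILE = 0.75 and 100 tasks, checking starts once 75 tasks
+ * have succeeded; if the median successful duration is 400 ms and SPECULATION_MULTIPLIER is
+ * 1.5, any single running copy older than max(1.5 * 400, 100) = 600 ms becomes speculatable.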
+ *
+ * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+ * we don't scan the whole task set. It might also help to make this sorted by launch time.
+ */
+ override def checkSpeculatableTasks(): Boolean = {
+ // Can't speculate if we only have one task, or if all tasks have finished.
+ if (numTasks == 1 || tasksSuccessful == numTasks) {
+ return false
+ }
+ var foundTasks = false
+ val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+ logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+ if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+ val time = clock.getTime()
+ val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+ Arrays.sort(durations)
+ val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ // TODO: Threshold should also look at standard deviation of task durations and have a lower
+ // bound based on that.
+ logDebug("Task length threshold for speculation: " + threshold)
+ for ((tid, info) <- taskInfos) {
+ val index = info.index
+ if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ !speculatableTasks.contains(index)) {
+ logInfo(
+ "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+ taskSet.id, index, info.host, threshold))
+ speculatableTasks += index
+ foundTasks = true
+ }
+ }
+ }
+ return foundTasks
+ }
+
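+ // How long to wait at the given locality level before giving up and moving to the next one.
+ // Each level reads its own property and falls back to spark.locality.wait; a wait of zero
+ // (e.g. -Dspark.locality.wait.node=0) removes that level from myLocalityLevels entirely.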
+ private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
+ val defaultWait = System.getProperty("spark.locality.wait", "3000")
+ level match {
+ case TaskLocality.PROCESS_LOCAL =>
+ System.getProperty("spark.locality.wait.process", defaultWait).toLong
+ case TaskLocality.NODE_LOCAL =>
+ System.getProperty("spark.locality.wait.node", defaultWait).toLong
+ case TaskLocality.RACK_LOCAL =>
+ System.getProperty("spark.locality.wait.rack", defaultWait).toLong
+ case TaskLocality.ANY =>
+ 0L
+ }
+ }
- def error(message: String)
+ /**
+ * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
+ * added to queues using addPendingTask.
+ */
+ private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
+ import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
+ val levels = new ArrayBuffer[TaskLocality.TaskLocality]
+ if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
+ levels += PROCESS_LOCAL
+ }
+ if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
+ levels += NODE_LOCAL
+ }
+ if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
+ levels += RACK_LOCAL
+ }
+ levels += ANY
+ logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
+ levels.toArray
+ }
}
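The delay-scheduling bookkeeping above is small enough to exercise in isolation. A minimal,
self-contained sketch (illustrative names only, not part of this patch):

    // Advance through locality levels as their wait times expire, crediting the time
    // already spent at each level so it is not counted twice.
    object DelaySchedulingSketch {
      val levels = Array("PROCESS_LOCAL", "NODE_LOCAL", "ANY")
      val waits = Array(3000L, 3000L, 0L)

      def allowedLevel(curTime: Long, startIndex: Int, startLaunch: Long): String = {
        var index = startIndex
        var lastLaunch = startLaunch
        while (index < levels.length - 1 && curTime - lastLaunch >= waits(index)) {
          lastLaunch += waits(index) // consume this level's wait before moving on
          index += 1
        }
        levels(index)
      }

      def main(args: Array[String]) {
        println(allowedLevel(2000L, 0, 0L)) // PROCESS_LOCAL: still within the first 3 s
        println(allowedLevel(4000L, 0, 0L)) // NODE_LOCAL: first wait expired 1 s ago
        println(allowedLevel(7000L, 0, 0L)) // ANY: both 3 s waits have been consumed
      }
    }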
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
index 938f62883a..ba6bab3f91 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
/**
* Represents free resources available on an executor.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
deleted file mode 100644
index a46b16b92f..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ /dev/null
@@ -1,714 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster
-
-import java.io.NotSerializableException
-import java.util.Arrays
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
-
-import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
-import org.apache.spark.util.{SystemClock, Clock}
-
-
-/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
- * the status of each task, retries tasks if they fail (up to a limited number of times), and
- * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
- * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
- * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
- *
- * THREADING: This class is designed to only be called from code with a lock on the
- * ClusterScheduler (e.g. its event handlers). It should not be called from other threads.
- */
-private[spark] class ClusterTaskSetManager(
- sched: ClusterScheduler,
- val taskSet: TaskSet,
- clock: Clock = SystemClock)
- extends TaskSetManager
- with Logging
-{
- val conf = sched.sc.conf
- // CPUs to request per task
- val CPUS_PER_TASK = conf.getOrElse("spark.task.cpus", "1").toInt
-
- // Maximum times a task is allowed to fail before failing the job
- val MAX_TASK_FAILURES = conf.getOrElse("spark.task.maxFailures", "4").toInt
-
- // Quantile of tasks at which to start speculation
- val SPECULATION_QUANTILE = conf.getOrElse("spark.speculation.quantile", "0.75").toDouble
- val SPECULATION_MULTIPLIER = conf.getOrElse("spark.speculation.multiplier", "1.5").toDouble
-
- // Serializer for closures and tasks.
- val env = SparkEnv.get
- val ser = env.closureSerializer.newInstance()
-
- val tasks = taskSet.tasks
- val numTasks = tasks.length
- val copiesRunning = new Array[Int](numTasks)
- val successful = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksSuccessful = 0
-
- var weight = 1
- var minShare = 0
- var priority = taskSet.priority
- var stageId = taskSet.stageId
- var name = "TaskSet_"+taskSet.stageId.toString
- var parent: Pool = null
-
- var runningTasks = 0
- private val runningTasksSet = new HashSet[Long]
-
- // Set of pending tasks for each executor. These collections are actually
- // treated as stacks, in which new tasks are added to the end of the
- // ArrayBuffer and removed from the end. This makes it faster to detect
- // tasks that repeatedly fail because whenever a task failed, it is put
- // back at the head of the stack. They are also only cleaned up lazily;
- // when a task is launched, it remains in all the pending lists except
- // the one that it was launched from, but gets removed from them later.
- private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
-
- // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
- // but at host level.
- private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
-
- // Set of pending tasks for each rack -- similar to the above.
- private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
-
- // Set containing pending tasks with no locality preferences.
- val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
-
- // Set containing all pending tasks (also used as a stack, as above).
- val allPendingTasks = new ArrayBuffer[Int]
-
- // Tasks that can be speculated. Since these will be a small fraction of total
- // tasks, we'll just hold them in a HashSet.
- val speculatableTasks = new HashSet[Int]
-
- // Task index, start and finish time for each task attempt (indexed by task ID)
- val taskInfos = new HashMap[Long, TaskInfo]
-
- // Did the TaskSet fail?
- var failed = false
- var causeOfFailure = ""
-
- // How frequently to reprint duplicate exceptions in full, in milliseconds
- val EXCEPTION_PRINT_INTERVAL =
- conf.getOrElse("spark.logging.exceptionPrintInterval", "10000").toLong
-
- // Map of recent exceptions (identified by string representation and top stack frame) to
- // duplicate count (how many times the same exception has appeared) and time the full exception
- // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
- val recentExceptions = HashMap[String, (Int, Long)]()
-
- // Figure out the current map output tracker epoch and set it on all tasks
- val epoch = sched.mapOutputTracker.getEpoch
- logDebug("Epoch for " + taskSet + ": " + epoch)
- for (t <- tasks) {
- t.epoch = epoch
- }
-
- // Add all our tasks to the pending lists. We do this in reverse order
- // of task index so that tasks with low indices get launched first.
- for (i <- (0 until numTasks).reverse) {
- addPendingTask(i)
- }
-
- // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
- val myLocalityLevels = computeValidLocalityLevels()
- val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
-
- // Delay scheduling variables: we keep track of our current locality level and the time we
- // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
- // We then move down if we manage to launch a "more local" task.
- var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
- var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
-
- /**
- * Add a task to all the pending-task lists that it should be on. If readding is set, we are
- * re-adding the task so only include it in each list if it's not already there.
- */
- private def addPendingTask(index: Int, readding: Boolean = false) {
- // Utility method that adds `index` to a list only if readding=false or it's not already there
- def addTo(list: ArrayBuffer[Int]) {
- if (!readding || !list.contains(index)) {
- list += index
- }
- }
-
- var hadAliveLocations = false
- for (loc <- tasks(index).preferredLocations) {
- for (execId <- loc.executorId) {
- if (sched.isExecutorAlive(execId)) {
- addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
- hadAliveLocations = true
- }
- }
- if (sched.hasExecutorsAliveOnHost(loc.host)) {
- addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
- for (rack <- sched.getRackForHost(loc.host)) {
- addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
- }
- hadAliveLocations = true
- }
- }
-
- if (!hadAliveLocations) {
- // Even though the task might've had preferred locations, all of those hosts or executors
- // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
- addTo(pendingTasksWithNoPrefs)
- }
-
- if (!readding) {
- allPendingTasks += index // No point scanning this whole list to find the old task there
- }
- }
-
- /**
- * Return the pending tasks list for a given executor ID, or an empty list if
- * there is no map entry for that host
- */
- private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
- pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
- }
-
- /**
- * Return the pending tasks list for a given host, or an empty list if
- * there is no map entry for that host
- */
- private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
- pendingTasksForHost.getOrElse(host, ArrayBuffer())
- }
-
- /**
- * Return the pending rack-local task list for a given rack, or an empty list if
- * there is no map entry for that rack
- */
- private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
- pendingTasksForRack.getOrElse(rack, ArrayBuffer())
- }
-
- /**
- * Dequeue a pending task from the given list and return its index.
- * Return None if the list is empty.
- * This method also cleans up any tasks in the list that have already
- * been launched, since we want that to happen lazily.
- */
- private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
- while (!list.isEmpty) {
- val index = list.last
- list.trimEnd(1)
- if (copiesRunning(index) == 0 && !successful(index)) {
- return Some(index)
- }
- }
- return None
- }
-
- /** Check whether a task is currently running an attempt on a given host */
- private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
- !taskAttempts(taskIndex).exists(_.host == host)
- }
-
- /**
- * Return a speculative task for a given executor if any are available. The task should not have
- * an attempt running on this host, in case the host is slow. In addition, the task should meet
- * the given locality constraint.
- */
- private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
- : Option[(Int, TaskLocality.Value)] =
- {
- speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
-
- if (!speculatableTasks.isEmpty) {
- // Check for process-local or preference-less tasks; note that tasks can be process-local
- // on multiple nodes when we replicate cached blocks, as in Spark Streaming
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val prefs = tasks(index).preferredLocations
- val executors = prefs.flatMap(_.executorId)
- if (prefs.size == 0 || executors.contains(execId)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
- }
-
- // Check for node-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val locations = tasks(index).preferredLocations.map(_.host)
- if (locations.contains(host)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.NODE_LOCAL))
- }
- }
- }
-
- // Check for rack-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- for (rack <- sched.getRackForHost(host)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
- if (racks.contains(rack)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.RACK_LOCAL))
- }
- }
- }
- }
-
- // Check for non-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.ANY))
- }
- }
- }
-
- return None
- }
-
- /**
- * Dequeue a pending task for a given node and return its index and locality level.
- * Only search for tasks matching the given locality constraint.
- */
- private def findTask(execId: String, host: String, locality: TaskLocality.Value)
- : Option[(Int, TaskLocality.Value)] =
- {
- for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
- for (index <- findTaskFromList(getPendingTasksForHost(host))) {
- return Some((index, TaskLocality.NODE_LOCAL))
- }
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- for {
- rack <- sched.getRackForHost(host)
- index <- findTaskFromList(getPendingTasksForRack(rack))
- } {
- return Some((index, TaskLocality.RACK_LOCAL))
- }
- }
-
- // Look for no-pref tasks after rack-local tasks since they can run anywhere.
- for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- for (index <- findTaskFromList(allPendingTasks)) {
- return Some((index, TaskLocality.ANY))
- }
- }
-
- // Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(execId, host, locality)
- }
-
- /**
- * Respond to an offer of a single executor from the scheduler by finding a task
- */
- override def resourceOffer(
- execId: String,
- host: String,
- availableCpus: Int,
- maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription] =
- {
- if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
- val curTime = clock.getTime()
-
- var allowedLocality = getAllowedLocalityLevel(curTime)
- if (allowedLocality > maxLocality) {
- allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
- }
-
- findTask(execId, host, allowedLocality) match {
- case Some((index, taskLocality)) => {
- // Found a task; do some bookkeeping and return a task description
- val task = tasks(index)
- val taskId = sched.newTaskId()
- // Figure out whether this should count as a preferred launch
- logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
- taskSet.id, index, taskId, execId, host, taskLocality))
- // Do various bookkeeping
- copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
- taskInfos(taskId) = info
- taskAttempts(index) = info :: taskAttempts(index)
- // Update our locality level for delay scheduling
- currentLocalityIndex = getLocalityIndex(taskLocality)
- lastLaunchTime = curTime
- // Serialize and return the task
- val startTime = clock.getTime()
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- val timeTaken = clock.getTime() - startTime
- addRunningTask(taskId)
- logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
- taskSet.id, index, serializedTask.limit, timeTaken))
- val taskName = "task %s:%d".format(taskSet.id, index)
- info.serializedSize = serializedTask.limit
- if (taskAttempts(index).size == 1)
- taskStarted(task,info)
- return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
- }
- case _ =>
- }
- }
- return None
- }
-
- /**
- * Get the level we can launch tasks according to delay scheduling, based on current wait time.
- */
- private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
- while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
- currentLocalityIndex < myLocalityLevels.length - 1)
- {
- // Jump to the next locality level, and remove our waiting time for the current one since
- // we don't want to count it again on the next one
- lastLaunchTime += localityWaits(currentLocalityIndex)
- currentLocalityIndex += 1
- }
- myLocalityLevels(currentLocalityIndex)
- }
-
- /**
- * Find the index in myLocalityLevels for a given locality. This is also designed to work with
- * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
- * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
- */
- def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
- var index = 0
- while (locality > myLocalityLevels(index)) {
- index += 1
- }
- index
- }
-
- private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.dagScheduler.taskStarted(task, info)
- }
-
- def handleTaskGettingResult(tid: Long) = {
- val info = taskInfos(tid)
- info.markGettingResult()
- sched.dagScheduler.taskGettingResult(tasks(info.index), info)
- }
-
- /**
- * Marks the task as successful and notifies the DAGScheduler that a task has ended.
- */
- def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
- val info = taskInfos(tid)
- val index = info.index
- info.markSuccessful()
- removeRunningTask(tid)
- if (!successful(index)) {
- logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
- tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.dagScheduler.taskEnded(
- tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
-
- // Mark successful and stop if all the tasks have succeeded.
- tasksSuccessful += 1
- successful(index) = true
- if (tasksSuccessful == numTasks) {
- sched.taskSetFinished(this)
- }
- } else {
- logInfo("Ignorning task-finished event for TID " + tid + " because task " +
- index + " has already completed successfully")
- }
- }
-
- /**
- * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
- * DAG Scheduler.
- */
- def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
- val info = taskInfos(tid)
- if (info.failed) {
- return
- }
- removeRunningTask(tid)
- val index = info.index
- info.markFailed()
- if (!successful(index)) {
- logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
- copiesRunning(index) -= 1
- // Check if the problem is a map output fetch failure. In that case, this
- // task will never succeed on any node, so tell the scheduler about it.
- reason.foreach {
- case fetchFailed: FetchFailed =>
- logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- successful(index) = true
- tasksSuccessful += 1
- sched.taskSetFinished(this)
- removeAllRunningTasks()
- return
-
- case TaskKilled =>
- logWarning("Task %d was killed.".format(tid))
- sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
- return
-
- case ef: ExceptionFailure =>
- sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
- if (ef.className == classOf[NotSerializableException].getName()) {
- // If the task result wasn't serializable, there's no point in trying to re-execute it.
- logError("Task %s:%s had a not serializable result: %s; not retrying".format(
- taskSet.id, index, ef.description))
- abort("Task %s:%s had a not serializable result: %s".format(
- taskSet.id, index, ef.description))
- return
- }
- val key = ef.description
- val now = clock.getTime()
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
- recentExceptions(key) = (0, now)
- (true, 0)
- }
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logWarning("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
- } else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
- }
-
- case TaskResultLost =>
- logWarning("Lost result for TID %s on host %s".format(tid, info.host))
- sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
-
- case _ => {}
- }
- // On non-fetch failures, re-enqueue the task as pending for a max number of retries
- addPendingTask(index)
- if (state != TaskState.KILLED) {
- numFailures(index) += 1
- if (numFailures(index) >= MAX_TASK_FAILURES) {
- logError("Task %s:%d failed %d times; aborting job".format(
- taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %s:%d failed %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
- }
- }
- } else {
- logInfo("Ignoring task-lost event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- override def error(message: String) {
- // Save the error message
- abort("Error: " + message)
- }
-
- def abort(message: String) {
- failed = true
- causeOfFailure = message
- // TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.dagScheduler.taskSetFailed(taskSet, message)
- removeAllRunningTasks()
- sched.taskSetFinished(this)
- }
-
- /** If the given task ID is not in the set of running tasks, adds it.
- *
- * Used to keep track of the number of running tasks, for enforcing scheduling policies.
- */
- def addRunningTask(tid: Long) {
- if (runningTasksSet.add(tid) && parent != null) {
- parent.increaseRunningTasks(1)
- }
- runningTasks = runningTasksSet.size
- }
-
- /** If the given task ID is in the set of running tasks, removes it. */
- def removeRunningTask(tid: Long) {
- if (runningTasksSet.remove(tid) && parent != null) {
- parent.decreaseRunningTasks(1)
- }
- runningTasks = runningTasksSet.size
- }
-
- private[cluster] def removeAllRunningTasks() {
- val numRunningTasks = runningTasksSet.size
- runningTasksSet.clear()
- if (parent != null) {
- parent.decreaseRunningTasks(numRunningTasks)
- }
- runningTasks = 0
- }
-
- override def getSchedulableByName(name: String): Schedulable = {
- return null
- }
-
- override def addSchedulable(schedulable: Schedulable) {}
-
- override def removeSchedulable(schedulable: Schedulable) {}
-
- override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
- sortedTaskSetQueue += this
- return sortedTaskSetQueue
- }
-
- /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */
- override def executorLost(execId: String, host: String) {
- logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
-
- // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
- // task that used to have locations on only this host might now go to the no-prefs list. Note
- // that it's okay if we add a task to the same queue twice (if it had multiple preferred
- // locations), because findTaskFromList will skip already-running tasks.
- for (index <- getPendingTasksForExecutor(execId)) {
- addPendingTask(index, readding=true)
- }
- for (index <- getPendingTasksForHost(host)) {
- addPendingTask(index, readding=true)
- }
-
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.executorId == execId) {
- val index = taskInfos(tid).index
- if (successful(index)) {
- successful(index) = false
- copiesRunning(index) -= 1
- tasksSuccessful -= 1
- addPendingTask(index)
- // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
- // stage finishes when a total of tasks.size tasks finish.
- sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
- }
- }
- }
- // Also re-enqueue any tasks that were running on the node
- for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- handleFailedTask(tid, TaskState.KILLED, None)
- }
- }
-
- /**
- * Check for tasks to be speculated and return true if there are any. This is called periodically
- * by the ClusterScheduler.
- *
- * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
- * we don't scan the whole task set. It might also help to make this sorted by launch time.
- */
- override def checkSpeculatableTasks(): Boolean = {
- // Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksSuccessful == numTasks) {
- return false
- }
- var foundTasks = false
- val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
- logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
- val time = clock.getTime()
- val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
- Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
- val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
- // TODO: Threshold should also look at standard deviation of task durations and have a lower
- // bound based on that.
- logDebug("Task length threshold for speculation: " + threshold)
- for ((tid, info) <- taskInfos) {
- val index = info.index
- if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
- !speculatableTasks.contains(index)) {
- logInfo(
- "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.host, threshold))
- speculatableTasks += index
- foundTasks = true
- }
- }
- }
- return foundTasks
- }
-
- override def hasPendingTasks(): Boolean = {
- numTasks > 0 && tasksSuccessful < numTasks
- }
-
- private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
- val defaultWait = conf.getOrElse("spark.locality.wait", "3000")
- level match {
- case TaskLocality.PROCESS_LOCAL =>
- conf.getOrElse("spark.locality.wait.process", defaultWait).toLong
- case TaskLocality.NODE_LOCAL =>
- conf.getOrElse("spark.locality.wait.node", defaultWait).toLong
- case TaskLocality.RACK_LOCAL =>
- conf.getOrElse("spark.locality.wait.rack", defaultWait).toLong
- case TaskLocality.ANY =>
- 0L
- }
- }
-
- /**
- * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
- * added to queues using addPendingTask.
- */
- private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
- import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
- val levels = new ArrayBuffer[TaskLocality.TaskLocality]
- if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
- levels += PROCESS_LOCAL
- }
- if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
- levels += NODE_LOCAL
- }
- if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
- levels += RACK_LOCAL
- }
- levels += ANY
- logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
- levels.toArray
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 156b01b149..b4a3ecca39 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -28,8 +28,10 @@ import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.spark.{Logging, SparkException, TaskState}
-import org.apache.spark.scheduler.TaskDescription
+import org.apache.spark.scheduler.{TaskSchedulerImpl, SchedulerBackend, SlaveLost, TaskDescription,
+ WorkerOffer}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -42,7 +44,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
* (spark.deploy.*).
*/
private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem)
extends SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index d74f000ebb..f41fbbd1f3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -19,10 +19,12 @@ package org.apache.spark.scheduler.cluster
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
+
import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.scheduler.TaskSchedulerImpl
private[spark] class SimrSchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext,
driverFilePath: String)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index de69e3260d..224077566d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -17,14 +17,16 @@
package org.apache.spark.scheduler.cluster
+import scala.collection.mutable.HashMap
+
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.deploy.client.{Client, ClientListener}
import org.apache.spark.deploy.{Command, ApplicationDescription}
-import scala.collection.mutable.HashMap
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext,
masters: Array[String],
appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 1695374152..9e2cd3f699 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,7 +30,8 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -43,7 +44,7 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedu
* remove this.
*/
private[spark] class CoarseMesosSchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext,
master: String,
appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 8dfd4d5fb3..be96382983 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -30,9 +30,8 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{Logging, SparkException, SparkContext, TaskState}
-import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason}
-import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer}
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost,
+ TaskDescription, TaskSchedulerImpl, WorkerOffer}
import org.apache.spark.util.Utils
/**
@@ -41,7 +40,7 @@ import org.apache.spark.util.Utils
* from multiple apps can run on different cores) and in time (a core can switch ownership).
*/
private[spark] class MesosSchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext,
master: String,
appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
new file mode 100644
index 0000000000..4edc6a0d3f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import java.nio.ByteBuffer
+
+import akka.actor.{Actor, ActorRef, Props}
+
+import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.executor.{Executor, ExecutorBackend}
+import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
+
+private case class ReviveOffers()
+
+private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+
+private case class KillTask(taskId: Long)
+
+/**
+ * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
+ * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
+ * and the TaskSchedulerImpl.
+ */
+private[spark] class LocalActor(
+ scheduler: TaskSchedulerImpl,
+ executorBackend: LocalBackend,
+ private val totalCores: Int) extends Actor with Logging {
+
+ private var freeCores = totalCores
+
+ private val localExecutorId = "localhost"
+ private val localExecutorHostname = "localhost"
+
+ val executor = new Executor(localExecutorId, localExecutorHostname, Seq.empty, isLocal = true)
+
+ def receive = {
+ case ReviveOffers =>
+ reviveOffers()
+
+ case StatusUpdate(taskId, state, serializedData) =>
+ scheduler.statusUpdate(taskId, state, serializedData)
+ if (TaskState.isFinished(state)) {
+ freeCores += 1
+ reviveOffers()
+ }
+
+ case KillTask(taskId) =>
+ executor.killTask(taskId)
+ }
+
+ def reviveOffers() {
+ val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
+ for (task <- scheduler.resourceOffers(offers).flatten) {
+ freeCores -= 1
+ executor.launchTask(executorBackend, task.taskId, task.serializedTask)
+ }
+ }
+}
+
+/**
+ * LocalBackend is used when running a local version of Spark where the executor, backend, and
+ * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks
+ * on a single Executor (created by the LocalBackend) running locally.
+ */
+private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
+ extends SchedulerBackend with ExecutorBackend {
+
+ var localActor: ActorRef = null
+
+ override def start() {
+ localActor = SparkEnv.get.actorSystem.actorOf(
+ Props(new LocalActor(scheduler, this, totalCores)),
+ "LocalBackendActor")
+ }
+
+ override def stop() {
+ }
+
+ override def reviveOffers() {
+ localActor ! ReviveOffers
+ }
+
+ override def defaultParallelism() = totalCores
+
+ override def killTask(taskId: Long, executorId: String) {
+ localActor ! KillTask(taskId)
+ }
+
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ localActor ! StatusUpdate(taskId, state, serializedData)
+ }
+}
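The actor round trip that keeps LocalBackend calls asynchronous can be sketched standalone. A
minimal demo (illustrative names only, not part of this patch; assumes the same Akka actor API
used above):

    import akka.actor.{Actor, ActorSystem, Props}

    case object Revive
    case class Finished(taskId: Long)

    // One free core: Revive launches a task if possible; Finished frees the core and
    // re-offers, mirroring the ReviveOffers/StatusUpdate loop in LocalActor.
    class DemoBackendActor extends Actor {
      var freeCores = 1
      var nextTask = 0L
      def receive = {
        case Revive if freeCores > 0 && nextTask < 3 =>
          freeCores -= 1
          println("launching task " + nextTask)
          self ! Finished(nextTask) // a real backend hands the task to an Executor instead
          nextTask += 1
        case Finished(id) =>
          println("task " + id + " finished")
          freeCores += 1
          self ! Revive
        case Revive => // no free core or no remaining work
      }
    }

    object LocalBackendSketch extends App {
      ActorSystem("demo").actorOf(Props(new DemoBackendActor)) ! Revive
    }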
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
deleted file mode 100644
index 7c173e3ad5..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ /dev/null
@@ -1,224 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-
-import akka.actor._
-
-import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.{Executor, ExecutorBackend}
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-
-
-/**
- * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
- * the scheduler also allows each task to fail up to maxFailures times, which is useful for
- * testing fault recovery.
- */
-
-private[local]
-case class LocalReviveOffers()
-
-private[local]
-case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
-
-private[local]
-case class KillTask(taskId: Long)
-
-private[spark]
-class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
- extends Actor with Logging {
-
- val executor = new Executor(
- "localhost", "localhost", localScheduler.sc.conf.getAll, isLocal = true)
-
- def receive = {
- case LocalReviveOffers =>
- launchTask(localScheduler.resourceOffer(freeCores))
-
- case LocalStatusUpdate(taskId, state, serializeData) =>
- if (TaskState.isFinished(state)) {
- freeCores += 1
- launchTask(localScheduler.resourceOffer(freeCores))
- }
-
- case KillTask(taskId) =>
- executor.killTask(taskId)
- }
-
- private def launchTask(tasks: Seq[TaskDescription]) {
- for (task <- tasks) {
- freeCores -= 1
- executor.launchTask(localScheduler, task.taskId, task.serializedTask)
- }
- }
-}
-
-private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext)
- extends TaskScheduler
- with ExecutorBackend
- with Logging {
-
- val env = SparkEnv.get
- val conf = env.conf
- val attemptId = new AtomicInteger
- var dagScheduler: DAGScheduler = null
-
- // Application dependencies (added through SparkContext) that we've fetched so far on this node.
- // Each map holds the master's timestamp for the version of that file or JAR we got.
- val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
- val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
-
- var schedulableBuilder: SchedulableBuilder = null
- var rootPool: Pool = null
- val schedulingMode: SchedulingMode = SchedulingMode.withName(
- conf.getOrElse("spark.scheduler.mode", "FIFO"))
- val activeTaskSets = new HashMap[String, LocalTaskSetManager]
- val taskIdToTaskSetId = new HashMap[Long, String]
- val taskSetTaskIds = new HashMap[String, HashSet[Long]]
-
- var localActor: ActorRef = null
-
- override def start() {
- // temporarily set rootPool name to empty
- rootPool = new Pool("", schedulingMode, 0, 0)
- schedulableBuilder = {
- schedulingMode match {
- case SchedulingMode.FIFO =>
- new FIFOSchedulableBuilder(rootPool)
- case SchedulingMode.FAIR =>
- new FairSchedulableBuilder(rootPool, conf)
- }
- }
- schedulableBuilder.buildPools()
-
- localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
- }
-
- override def setDAGScheduler(dagScheduler: DAGScheduler) {
- this.dagScheduler = dagScheduler
- }
-
- override def submitTasks(taskSet: TaskSet) {
- synchronized {
- val manager = new LocalTaskSetManager(this, taskSet)
- schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
- activeTaskSets(taskSet.id) = manager
- taskSetTaskIds(taskSet.id) = new HashSet[Long]()
- localActor ! LocalReviveOffers
- }
- }
-
- override def cancelTasks(stageId: Int): Unit = synchronized {
- logInfo("Cancelling stage " + stageId)
- logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
- activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
- // There are two possible cases here:
- // 1. The task set manager has been created and some tasks have been scheduled.
- // In this case, send a kill signal to the executors to kill the task and then abort
- // the stage.
- // 2. The task set manager has been created but no tasks has been scheduled. In this case,
- // simply abort the stage.
- val taskIds = taskSetTaskIds(tsm.taskSet.id)
- if (taskIds.size > 0) {
- taskIds.foreach { tid =>
- localActor ! KillTask(tid)
- }
- }
- logInfo("Stage %d was cancelled".format(stageId))
- taskSetFinished(tsm)
- }
- }
-
- def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
- synchronized {
- var freeCpuCores = freeCores
- val tasks = new ArrayBuffer[TaskDescription](freeCores)
- val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
- for (manager <- sortedTaskSetQueue) {
- logDebug("parentName:%s,name:%s,runningTasks:%s".format(
- manager.parent.name, manager.name, manager.runningTasks))
- }
-
- var launchTask = false
- for (manager <- sortedTaskSetQueue) {
- do {
- launchTask = false
- manager.resourceOffer(null, null, freeCpuCores, null) match {
- case Some(task) =>
- tasks += task
- taskIdToTaskSetId(task.taskId) = manager.taskSet.id
- taskSetTaskIds(manager.taskSet.id) += task.taskId
- freeCpuCores -= 1
- launchTask = true
- case None => {}
- }
- } while(launchTask)
- }
- return tasks
- }
- }
-
- def taskSetFinished(manager: TaskSetManager) {
- synchronized {
- activeTaskSets -= manager.taskSet.id
- manager.parent.removeSchedulable(manager)
- logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
- taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskSetTaskIds -= manager.taskSet.id
- }
- }
-
- override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
- if (TaskState.isFinished(state)) {
- synchronized {
- taskIdToTaskSetId.get(taskId) match {
- case Some(taskSetId) =>
- val taskSetManager = activeTaskSets.get(taskSetId)
- taskSetManager.foreach { tsm =>
- taskSetTaskIds(taskSetId) -= taskId
-
- state match {
- case TaskState.FINISHED =>
- tsm.taskEnded(taskId, state, serializedData)
- case TaskState.FAILED =>
- tsm.taskFailed(taskId, state, serializedData)
- case TaskState.KILLED =>
- tsm.error("Task %d was killed".format(taskId))
- case _ => {}
- }
- }
- case None =>
- logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
- }
- }
- localActor ! LocalStatusUpdate(taskId, state, serializedData)
- }
- }
-
- override def stop() {
- }
-
- override def defaultParallelism() = threads
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
deleted file mode 100644
index 53bf78267e..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.nio.ByteBuffer
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task,
- TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager}
-
-
-private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
- extends TaskSetManager with Logging {
-
- var parent: Pool = null
- var weight: Int = 1
- var minShare: Int = 0
- var runningTasks: Int = 0
- var priority: Int = taskSet.priority
- var stageId: Int = taskSet.stageId
- var name: String = "TaskSet_" + taskSet.stageId.toString
-
- var failCount = new Array[Int](taskSet.tasks.size)
- val taskInfos = new HashMap[Long, TaskInfo]
- val numTasks = taskSet.tasks.size
- var numFinished = 0
- val env = SparkEnv.get
- val ser = env.closureSerializer.newInstance()
- val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val MAX_TASK_FAILURES = sched.maxFailures
-
- def increaseRunningTasks(taskNum: Int): Unit = {
- runningTasks += taskNum
- if (parent != null) {
- parent.increaseRunningTasks(taskNum)
- }
- }
-
- def decreaseRunningTasks(taskNum: Int): Unit = {
- runningTasks -= taskNum
- if (parent != null) {
- parent.decreaseRunningTasks(taskNum)
- }
- }
-
- override def addSchedulable(schedulable: Schedulable): Unit = {
- // nothing
- }
-
- override def removeSchedulable(schedulable: Schedulable): Unit = {
- // nothing
- }
-
- override def getSchedulableByName(name: String): Schedulable = {
- return null
- }
-
- override def executorLost(executorId: String, host: String): Unit = {
- // nothing
- }
-
- override def checkSpeculatableTasks() = true
-
- override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
- sortedTaskSetQueue += this
- return sortedTaskSetQueue
- }
-
- override def hasPendingTasks() = true
-
- def findTask(): Option[Int] = {
- for (i <- 0 to numTasks-1) {
- if (copiesRunning(i) == 0 && !finished(i)) {
- return Some(i)
- }
- }
- return None
- }
-
- override def resourceOffer(
- execId: String,
- host: String,
- availableCpus: Int,
- maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription] =
- {
- SparkEnv.set(sched.env)
- logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format(
- availableCpus.toInt, numFinished, numTasks))
- if (availableCpus > 0 && numFinished < numTasks) {
- findTask() match {
- case Some(index) =>
- val taskId = sched.attemptId.getAndIncrement()
- val task = taskSet.tasks(index)
- val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1",
- TaskLocality.NODE_LOCAL)
- taskInfos(taskId) = info
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val bytes = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
- val taskName = "task %s:%d".format(taskSet.id, index)
- copiesRunning(index) += 1
- increaseRunningTasks(1)
- taskStarted(task, info)
- return Some(new TaskDescription(taskId, null, taskName, index, bytes))
- case None => {}
- }
- }
- return None
- }
-
- def taskStarted(task: Task[_], info: TaskInfo) {
- sched.dagScheduler.taskStarted(task, info)
- }
-
- def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- val index = info.index
- val task = taskSet.tasks(index)
- info.markSuccessful()
- val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match {
- case directResult: DirectTaskResult[_] => directResult
- case IndirectTaskResult(blockId) => {
- throw new SparkException("Expect only DirectTaskResults when using LocalScheduler")
- }
- }
- result.metrics.resultSize = serializedData.limit()
- sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
- result.metrics)
- numFinished += 1
- decreaseRunningTasks(1)
- finished(index) = true
- if (numFinished == numTasks) {
- sched.taskSetFinished(this)
- }
- }
-
- def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- val index = info.index
- val task = taskSet.tasks(index)
- info.markFailed()
- decreaseRunningTasks(1)
- val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
- serializedData, getClass.getClassLoader)
- sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
- if (!finished(index)) {
- copiesRunning(index) -= 1
- numFailures(index) += 1
- val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(
- reason.className, reason.description, locs.mkString("\n")))
- if (numFailures(index) > MAX_TASK_FAILURES) {
- val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
- taskSet.id, index, MAX_TASK_FAILURES, reason.description)
- decreaseRunningTasks(runningTasks)
- sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
- // need to delete failed Taskset from schedule queue
- sched.taskSetFinished(this)
- }
- }
- }
-
- override def error(message: String) {
- sched.dagScheduler.taskSetFailed(taskSet, message)
- sched.taskSetFinished(this)
- }
-}
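
The manager deleted above (together with LocalScheduler and LocalSchedulerSuite further down) is superseded by running the regular TaskSchedulerImpl and TaskSetManager on top of a thin local backend; the updated SparkContextSchedulerCreationSuite below matches on a LocalBackend exposing totalCores. A minimal sketch of that shape, assuming the SchedulerBackend and WorkerOffer interfaces renamed elsewhere in this patch; the body is illustrative, not the actual LocalBackend:

    import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}

    // Sketch only: serve resource offers from a fixed pool of local cores.
    // The real LocalBackend also hosts an executor thread pool and feeds
    // task status updates back into the scheduler.
    class SketchLocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
      extends SchedulerBackend {

      private var freeCores = totalCores

      override def start() {}
      override def stop() {}
      override def defaultParallelism(): Int = totalCores

      // Offer every free core on a single local "executor"; the scheduler
      // picks which pending tasks to run on them.
      override def reviveOffers() {
        val offers = Seq(new WorkerOffer("localhost", "localhost", freeCores))
        for (task <- scheduler.resourceOffers(offers).flatten) {
          freeCores -= 1
          // ... run the task on a local thread, then return the core and
          // call reviveOffers() again when it completes.
        }
      }
    }
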
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index f592df283a..151eedb783 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -74,10 +74,16 @@ class ShuffleBlockManager(blockManager: BlockManager) {
* Contains all the state related to a particular shuffle. This includes a pool of unused
* ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
*/
- private class ShuffleState() {
+ private class ShuffleState(val numBuckets: Int) {
val nextFileId = new AtomicInteger(0)
val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+
+ /**
+ * The mapIds of all map tasks completed on this Executor for this shuffle.
+ * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise.
+ */
+ val completedMapTasks = new ConcurrentLinkedQueue[Int]()
}
type ShuffleId = Int
@@ -88,7 +94,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleWriterGroup {
- shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
private val shuffleState = shuffleStates(shuffleId)
private var fileGroup: ShuffleFileGroup = null
@@ -113,6 +119,8 @@ class ShuffleBlockManager(blockManager: BlockManager) {
fileGroup.recordMapOutput(mapId, offsets)
}
recycleFileGroup(fileGroup)
+ } else {
+ shuffleState.completedMapTasks.add(mapId)
}
}
@@ -158,7 +166,18 @@ class ShuffleBlockManager(blockManager: BlockManager) {
}
private def cleanup(cleanupTime: Long) {
- shuffleStates.clearOldValues(cleanupTime)
+ shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => {
+ if (consolidateShuffleFiles) {
+ for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
+ file.delete()
+ }
+ } else {
+ for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
+ val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
+ blockManager.diskBlockManager.getFile(blockId).delete()
+ }
+ }
+ })
}
}
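
For orientation, the bookkeeping added above sits inside the writer-group lifecycle: a map task obtains one writer per reduce bucket via forMapTask, and on release the group is either recycled (consolidated mode) or the task's mapId is recorded so cleanup() can later delete each per-(map, reduce) file. A rough usage sketch; `records`, the write loop, and the zero-argument releaseWriters() are assumptions rather than verbatim API:

    // Hedged sketch of a map task driving the ShuffleBlockManager above.
    val group = shuffleBlockManager.forMapTask(shuffleId, mapId, numBuckets, serializer)
    try {
      for ((elem, bucket) <- records) {
        group.writers(bucket).write(elem)   // one writer per reduce bucket
      }
    } finally {
      // Consolidated mode: recycle the file group for reuse.
      // Otherwise: record mapId in completedMapTasks for cleanup().
      group.releaseWriters()
    }
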
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index e596690bc3..a31a7e1d58 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -56,7 +56,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_)
val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used",
- "Active tasks", "Failed tasks", "Complete tasks", "Total tasks")
+ "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Task Time", "Shuffle Read",
+ "Shuffle Write")
def execRow(kv: Seq[String]) = {
<tr>
@@ -73,6 +74,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
<td>{kv(7)}</td>
<td>{kv(8)}</td>
<td>{kv(9)}</td>
+ <td>{Utils.msDurationToString(kv(10).toLong)}</td>
+ <td>{Utils.bytesToString(kv(11).toLong)}</td>
+ <td>{Utils.bytesToString(kv(12).toLong)}</td>
</tr>
}
@@ -111,6 +115,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
val totalTasks = activeTasks + failedTasks + completedTasks
+ val totalDuration = listener.executorToDuration.getOrElse(execId, 0)
+ val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0)
+ val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0)
Seq(
execId,
@@ -122,7 +129,10 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
activeTasks.toString,
failedTasks.toString,
completedTasks.toString,
- totalTasks.toString
+ totalTasks.toString,
+ totalDuration.toString,
+ totalShuffleRead.toString,
+ totalShuffleWrite.toString
)
}
@@ -130,6 +140,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]()
val executorToTasksComplete = HashMap[String, Int]()
val executorToTasksFailed = HashMap[String, Int]()
+ val executorToDuration = HashMap[String, Long]()
+ val executorToShuffleRead = HashMap[String, Long]()
+ val executorToShuffleWrite = HashMap[String, Long]()
override def onTaskStart(taskStart: SparkListenerTaskStart) {
val eid = taskStart.taskInfo.executorId
@@ -140,6 +153,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
val eid = taskEnd.taskInfo.executorId
val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]())
+ val newDuration = executorToDuration.getOrElse(eid, 0L) + taskEnd.taskInfo.duration
+ executorToDuration.put(eid, newDuration)
+
activeTasks -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
@@ -150,6 +166,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
+
+ // update shuffle read/write
+ if (null != taskEnd.taskMetrics) {
+ taskEnd.taskMetrics.shuffleReadMetrics.foreach(shuffleRead =>
+ executorToShuffleRead.put(eid, executorToShuffleRead.getOrElse(eid, 0L) +
+ shuffleRead.remoteBytesRead))
+
+ taskEnd.taskMetrics.shuffleWriteMetrics.foreach(shuffleWrite =>
+ executorToShuffleWrite.put(eid, executorToShuffleWrite.getOrElse(eid, 0L) +
+ shuffleWrite.shuffleBytesWritten))
+ }
}
}
}
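
The listener additions above repeat one idiom: read the executor's running total with a zero default, add the new sample, write it back. A self-contained sketch of just that pattern; the TaskSample case class is a hypothetical stand-in for the fields read off SparkListenerTaskEnd:

    import scala.collection.mutable.HashMap

    // Hypothetical stand-in for taskEnd.taskInfo / taskEnd.taskMetrics fields.
    case class TaskSample(executorId: String, duration: Long, shuffleBytesRead: Long)

    val executorToDuration = HashMap[String, Long]()
    val executorToShuffleRead = HashMap[String, Long]()

    def record(sample: TaskSample) {
      // getOrElse(_, 0L) starts each executor's totals from zero.
      executorToDuration.put(sample.executorId,
        executorToDuration.getOrElse(sample.executorId, 0L) + sample.duration)
      executorToShuffleRead.put(sample.executorId,
        executorToShuffleRead.getOrElse(sample.executorId, 0L) + sample.shuffleBytesRead)
    }

    record(TaskSample("exec-1", duration = 120L, shuffleBytesRead = 4096L))
    record(TaskSample("exec-1", duration = 80L, shuffleBytesRead = 1024L))
    // executorToShuffleRead("exec-1") == 5120L
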
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
new file mode 100644
index 0000000000..3c53e88380
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+/** Class for reporting aggregated metrics for each executor in the stage UI. */
+private[spark] class ExecutorSummary {
+ var taskTime: Long = 0
+ var failedTasks: Int = 0
+ var succeededTasks: Int = 0
+ var shuffleRead: Long = 0
+ var shuffleWrite: Long = 0
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
new file mode 100644
index 0000000000..0dd876480a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.xml.Node
+
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.util.Utils
+import scala.collection.mutable
+
+/** Table showing aggregated per-executor metrics for a stage. */
+private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) {
+
+ val listener = parent.listener
+ val dateFmt = parent.dateFmt
+ val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR
+
+ def toNodeSeq(): Seq[Node] = {
+ listener.synchronized {
+ executorTable()
+ }
+ }
+
+ /** Renders the summary table; one row per executor. */
+ private def executorTable(): Seq[Node] = {
+ <table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <th>Executor ID</th>
+ <th>Address</th>
+ <th>Task Time</th>
+ <th>Total Tasks</th>
+ <th>Failed Tasks</th>
+ <th>Succeeded Tasks</th>
+ <th>Shuffle Read</th>
+ <th>Shuffle Write</th>
+ </thead>
+ <tbody>
+ {createExecutorTable()}
+ </tbody>
+ </table>
+ }
+
+ private def createExecutorTable(): Seq[Node] = {
+ // make an executor-id -> address map
+ val executorIdToAddress = mutable.HashMap[String, String]()
+ val storageStatusList = parent.sc.getExecutorStorageStatus
+ for (statusId <- 0 until storageStatusList.size) {
+ val blockManagerId = parent.sc.getExecutorStorageStatus(statusId).blockManagerId
+ val address = blockManagerId.hostPort
+ val executorId = blockManagerId.executorId
+ executorIdToAddress.put(executorId, address)
+ }
+
+ val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId)
+ executorIdToSummary match {
+ case Some(x) => {
+ x.toSeq.sortBy(_._1).map{
+ case (k,v) => {
+ <tr>
+ <td>{k}</td>
+ <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td>
+ <td>{parent.formatDuration(v.taskTime)}</td>
+ <td>{v.failedTasks + v.succeededTasks}</td>
+ <td>{v.failedTasks}</td>
+ <td>{v.succeededTasks}</td>
+ <td>{Utils.bytesToString(v.shuffleRead)}</td>
+ <td>{Utils.bytesToString(v.shuffleWrite)}</td>
+ </tr>
+ }
+ }
+ }
+ case _ => { Seq[Node]() }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 6ff8e9fb14..eed3544b70 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -57,6 +57,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val stageIdToTasksFailed = HashMap[Int, Int]()
val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
+ val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
@@ -124,8 +125,38 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
+
+ // create executor summary map if necessary
+ val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid,
+ op = new HashMap[String, ExecutorSummary]())
+ executorSummaryMap.getOrElseUpdate(key = taskEnd.taskInfo.executorId,
+ op = new ExecutorSummary())
+
+ val executorSummary = executorSummaryMap.get(taskEnd.taskInfo.executorId)
+ executorSummary match {
+ case Some(y) => {
+ // first update the failed and succeeded task counts
+ taskEnd.reason match {
+ case Success =>
+ y.succeededTasks += 1
+ case _ =>
+ y.failedTasks += 1
+ }
+
+ // update duration
+ y.taskTime += taskEnd.taskInfo.duration
+
+ Option(taskEnd.taskMetrics).foreach { taskMetrics =>
+ taskMetrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead }
+ taskMetrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten }
+ }
+ }
+ case _ => {}
+ }
+
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
+
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
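
One note on the hunk above: getOrElseUpdate already returns the existing-or-freshly-created value, so the second lookup and the Option match are avoidable. A behavior-equivalent sketch of the tighter form, using only names from the patch:

    val summary = stageIdToExecutorSummaries
      .getOrElseUpdate(sid, new HashMap[String, ExecutorSummary]())
      .getOrElseUpdate(taskEnd.taskInfo.executorId, new ExecutorSummary())

    taskEnd.reason match {
      case Success => summary.succeededTasks += 1
      case _       => summary.failedTasks += 1
    }
    summary.taskTime += taskEnd.taskInfo.duration

    Option(taskEnd.taskMetrics).foreach { m =>
      m.shuffleReadMetrics.foreach { summary.shuffleRead += _.remoteBytesRead }
      m.shuffleWriteMetrics.foreach { summary.shuffleWrite += _.shuffleBytesWritten }
    }
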
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 996e1b4d1a..8dcfeacb60 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -66,7 +66,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
<div>
<ul class="unstyled">
<li>
- <strong>Total duration across all tasks: </strong>
+ <strong>Total task time across all tasks: </strong>
{parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
@@ -166,11 +166,12 @@ private[spark] class StagePage(parent: JobProgressUI) {
def quantileRow(data: Seq[String]): Seq[Node] = <tr> {data.map(d => <td>{d}</td>)} </tr>
Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
}
-
+ val executorTable = new ExecutorTable(parent, stageId)
val content =
summary ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
+ <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq() ++
<h4>Tasks</h4> ++ taskTable
headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 9ad6de3c6d..463d85dfd5 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
{if (isFairScheduler) {<th>Pool Name</th>} else {}}
<th>Description</th>
<th>Submitted</th>
- <th>Duration</th>
+ <th>Task Time</th>
<th>Tasks: Succeeded/Total</th>
<th>Shuffle Read</th>
<th>Shuffle Write</th>
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 431d88838f..9ea7fc2dfd 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -32,7 +32,7 @@ class MetadataCleaner(
{
val name = cleanerType.toString
- private val delaySeconds = MetadataCleaner.getDelaySeconds(conf)
+ private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType)
private val periodSeconds = math.max(10, delaySeconds / 10)
private val timer = new Timer(name + " cleanup timer", true)
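
The one-line change above threads the cleaner type into the delay lookup, letting each metadata cleaner carry its own TTL instead of sharing one global value. A sketch of how such a lookup could resolve, assuming a "spark.cleaner.ttl.<type>" key scheme with a global fallback; the real getDelaySeconds may differ:

    // Assumed per-type key scheme with a global fallback (sketch only).
    def getDelaySeconds(conf: SparkConf,
                        cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
      val globalDelay = conf.getOrElse("spark.cleaner.ttl", "-1").toInt  // illustrative default
      conf.getOrElse("spark.cleaner.ttl." + cleanerType.toString, globalDelay.toString).toInt
    }
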
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index dbff571de9..181ae2fd45 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -104,19 +104,28 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
def toMap: immutable.Map[A, B] = iterator.toMap
/**
- * Removes old key-value pairs that have timestamp earlier than `threshTime`
+ * Removes old key-value pairs that have timestamp earlier than `threshTime`,
+ * calling the supplied function on each such entry before removing.
*/
- def clearOldValues(threshTime: Long) {
+ def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
val iterator = internalMap.entrySet().iterator()
- while(iterator.hasNext) {
+ while (iterator.hasNext) {
val entry = iterator.next()
if (entry.getValue._2 < threshTime) {
+ f(entry.getKey, entry.getValue._1)
logDebug("Removing key " + entry.getKey)
iterator.remove()
}
}
}
+ /**
+ * Removes old key-value pairs that have timestamp earlier than `threshTime`
+ */
+ def clearOldValues(threshTime: Long) {
+ clearOldValues(threshTime, (_, _) => ())
+ }
+
private def currentTime: Long = System.currentTimeMillis()
}
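
A small usage sketch of the new overload; the keys and the threshold are illustrative:

    val map = new TimeStampedHashMap[String, Int]
    map("stage-1") = 42
    map("stage-2") = 7

    // Every entry stamped before threshTime is handed to the callback and
    // then removed; the original no-callback overload now delegates here
    // with a no-op function.
    val threshTime = System.currentTimeMillis() + 1
    map.clearOldValues(threshTime, (key, value) => println("dropping " + key + " = " + value))
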
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index af448fcb37..befdc1589f 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -42,7 +42,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
test("failure in a single-stage job") {
- sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1,2]", "test")
val results = sc.makeRDD(1 to 3, 3).map { x =>
FailureSuiteState.synchronized {
FailureSuiteState.tasksRun += 1
@@ -62,7 +62,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
// Run a map-reduce job in which a reduce task deterministically fails once.
test("failure in a two-stage job") {
- sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1,2]", "test")
val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
case (k, v) =>
FailureSuiteState.synchronized {
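
These master-string bumps track a semantic shift in the unified scheduler: the second figure in local[N, F] is now the total attempt budget per task, whereas the LocalTaskSetManager deleted at the top of this patch aborted only when failures exceeded MAX_TASK_FAILURES, so local[1,1] still permitted one retry. A task that deterministically fails once therefore now needs local[1,2]. A sketch, keeping cross-attempt state in a singleton the way the suite's FailureSuiteState does:

    // "local[1,2]": one core, each task gets at most two attempts.
    object RetryState { var failedOnce = false }

    val sc = new SparkContext("local[1,2]", "retry-example")
    val results = sc.makeRDD(1 to 3, 3).map { x =>
      RetryState.synchronized {
        if (x == 1 && !RetryState.failedOnce) {
          RetryState.failedOnce = true
          throw new Exception("deterministic first-attempt failure")
        }
      }
      x
    }.collect()
    // results.sorted is Array(1, 2, 3): the failed task succeeded on retry.
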
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index 151af0d213..f28d5c7b13 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -19,20 +19,21 @@ package org.apache.spark
import org.scalatest.{FunSuite, PrivateMethodTester}
-import org.apache.spark.scheduler.TaskScheduler
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend}
+import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskScheduler}
+import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.local.LocalBackend
class SparkContextSchedulerCreationSuite
extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging {
- def createTaskScheduler(master: String): TaskScheduler = {
+ def createTaskScheduler(master: String): TaskSchedulerImpl = {
// Create local SparkContext to setup a SparkEnv. We don't actually want to start() the
// real schedulers, so we don't want to create a full SparkContext with the desired scheduler.
sc = new SparkContext("local", "test")
val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler)
- SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test")
+ val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test")
+ sched.asInstanceOf[TaskSchedulerImpl]
}
test("bad-master") {
@@ -43,55 +44,49 @@ class SparkContextSchedulerCreationSuite
}
test("local") {
- createTaskScheduler("local") match {
- case s: LocalScheduler =>
- assert(s.threads === 1)
- assert(s.maxFailures === 0)
+ val sched = createTaskScheduler("local")
+ sched.backend match {
+ case s: LocalBackend => assert(s.totalCores === 1)
case _ => fail()
}
}
test("local-n") {
- createTaskScheduler("local[5]") match {
- case s: LocalScheduler =>
- assert(s.threads === 5)
- assert(s.maxFailures === 0)
+ val sched = createTaskScheduler("local[5]")
+ assert(sched.maxTaskFailures === 1)
+ sched.backend match {
+ case s: LocalBackend => assert(s.totalCores === 5)
case _ => fail()
}
}
test("local-n-failures") {
- createTaskScheduler("local[4, 2]") match {
- case s: LocalScheduler =>
- assert(s.threads === 4)
- assert(s.maxFailures === 2)
+ val sched = createTaskScheduler("local[4, 2]")
+ assert(sched.maxTaskFailures === 2)
+ sched.backend match {
+ case s: LocalBackend => assert(s.totalCores === 4)
case _ => fail()
}
}
test("simr") {
- createTaskScheduler("simr://uri") match {
- case s: ClusterScheduler =>
- assert(s.backend.isInstanceOf[SimrSchedulerBackend])
+ createTaskScheduler("simr://uri").backend match {
+ case s: SimrSchedulerBackend => // OK
case _ => fail()
}
}
test("local-cluster") {
- createTaskScheduler("local-cluster[3, 14, 512]") match {
- case s: ClusterScheduler =>
- assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend])
+ createTaskScheduler("local-cluster[3, 14, 512]").backend match {
+ case s: SparkDeploySchedulerBackend => // OK
case _ => fail()
}
}
def testYarn(master: String, expectedClassName: String) {
try {
- createTaskScheduler(master) match {
- case s: ClusterScheduler =>
- assert(s.getClass === Class.forName(expectedClassName))
- case _ => fail()
- }
+ val sched = createTaskScheduler(master)
+ assert(sched.getClass === Class.forName(expectedClassName))
} catch {
case e: SparkException =>
assert(e.getMessage.contains("YARN mode not available"))
@@ -110,11 +105,8 @@ class SparkContextSchedulerCreationSuite
def testMesos(master: String, expectedClass: Class[_]) {
try {
- createTaskScheduler(master) match {
- case s: ClusterScheduler =>
- assert(s.backend.getClass === expectedClass)
- case _ => fail()
- }
+ val sched = createTaskScheduler(master)
+ assert(sched.backend.getClass === expectedClass)
} catch {
case e: UnsatisfiedLinkError =>
assert(e.getMessage.contains("no mesos in"))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index 34d2e4cb8c..7bf2020fe3 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -15,14 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark._
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster._
import scala.collection.mutable.ArrayBuffer
import java.util.Properties
@@ -31,9 +29,9 @@ class FakeTaskSetManager(
initPriority: Int,
initStageId: Int,
initNumTasks: Int,
- clusterScheduler: ClusterScheduler,
+ clusterScheduler: TaskSchedulerImpl,
taskSet: TaskSet)
- extends ClusterTaskSetManager(clusterScheduler, taskSet) {
+ extends TaskSetManager(clusterScheduler, taskSet, 0) {
parent = null
weight = 1
@@ -106,7 +104,7 @@ class FakeTaskSetManager(
class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
- def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = {
+ def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = {
new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
}
@@ -133,7 +131,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
test("FIFO Scheduler Test") {
sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new ClusterScheduler(sc)
+ val clusterScheduler = new TaskSchedulerImpl(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
@@ -160,7 +158,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
test("Fair Scheduler Test") {
sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new ClusterScheduler(sc)
+ val clusterScheduler = new TaskSchedulerImpl(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
@@ -217,7 +215,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
test("Nested Pool Test") {
sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new ClusterScheduler(sc)
+ val clusterScheduler = new TaskSchedulerImpl(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0f01515179..0b90c4e74c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -15,10 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.TaskContext
-import org.apache.spark.scheduler.{TaskLocation, Task}
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
override def runTask(context: TaskContext): Int = 0
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 2e41438a52..d4320e5e14 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -19,23 +19,26 @@ package org.apache.spark.scheduler
import scala.collection.mutable.{Buffer, HashSet}
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
- with BeforeAndAfterAll {
+ with BeforeAndAfter with BeforeAndAfterAll {
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
+ before {
+ sc = new SparkContext("local", "SparkListenerSuite")
+ }
+
override def afterAll {
System.clearProperty("spark.akka.frameSize")
}
test("basic creation of StageInfo") {
- sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -56,7 +59,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("StageInfo with fewer tasks than partitions") {
- sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -72,7 +74,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("local metrics") {
- sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -135,10 +136,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("onTaskGettingResult() called when result fetched remotely") {
- // Need to use local cluster mode here, because results are not ever returned through the
- // block manager when using the LocalScheduler.
- sc = new SparkContext("local-cluster[1,1,512]", "test")
-
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
@@ -157,10 +154,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("onTaskGettingResult() not called when result sent directly") {
- // Need to use local cluster mode here, because results are not ever returned through the
- // block manager when using the LocalScheduler.
- sc = new SparkContext("local-cluster[1,1,512]", "test")
-
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 618fae7c16..4b52d9651e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -15,14 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
-import org.apache.spark.{SparkConf, LocalSparkContext, SparkContext, SparkEnv}
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.storage.TaskResultBlockId
/**
@@ -31,12 +30,12 @@ import org.apache.spark.storage.TaskResultBlockId
* Used to test the case where a BlockManager evicts the task result (or dies) before the
* TaskResult is retrieved.
*/
-class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
extends TaskResultGetter(sparkEnv, scheduler) {
var removedResult = false
override def enqueueSuccessfulTask(
- taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
if (!removedResult) {
// Only remove the result once, since we'd like to test the case where the task eventually
// succeeds.
@@ -44,13 +43,13 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSched
case IndirectTaskResult(blockId) =>
sparkEnv.blockManager.master.removeBlock(blockId)
case directResult: DirectTaskResult[_] =>
- taskSetManager.abort("Internal error: expect only indirect results")
+ taskSetManager.abort("Internal error: expect only indirect results")
}
serializedData.rewind()
removedResult = true
}
super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
- }
+ }
}
/**
@@ -65,22 +64,18 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
System.setProperty("spark.akka.frameSize", "1")
}
- before {
- // Use local-cluster mode because results are returned differently when running with the
- // LocalScheduler.
- sc = new SparkContext("local-cluster[1,1,512]", "test")
- }
-
override def afterAll {
System.clearProperty("spark.akka.frameSize")
}
test("handling results smaller than Akka frame size") {
+ sc = new SparkContext("local", "test")
val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
assert(result === 2)
}
- test("handling results larger than Akka frame size") {
+ test("handling results larger than Akka frame size") {
+ sc = new SparkContext("local", "test")
val akkaFrameSize =
sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
@@ -92,10 +87,13 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
}
test("task retried if result missing from block manager") {
+ // Set the maximum number of task failures to be greater than one, so that the task set
+ // isn't aborted the first time the result goes missing.
+ sc = new SparkContext("local[1,2]", "test")
// If this test hangs, it's probably because no resource offers were made after the task
// failed.
- val scheduler: ClusterScheduler = sc.taskScheduler match {
- case clusterScheduler: ClusterScheduler =>
+ val scheduler: TaskSchedulerImpl = sc.taskScheduler match {
+ case clusterScheduler: TaskSchedulerImpl =>
clusterScheduler
case _ =>
assert(false, "Expect local cluster to use ClusterScheduler")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 3711382f2e..5d33e66253 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
@@ -23,7 +23,6 @@ import scala.collection.mutable
import org.scalatest.FunSuite
import org.apache.spark._
-import org.apache.spark.scheduler._
import org.apache.spark.executor.TaskMetrics
import java.nio.ByteBuffer
import org.apache.spark.util.{Utils, FakeClock}
@@ -56,10 +55,10 @@ class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler
* A mock ClusterScheduler implementation that just remembers information about tasks started and
* feedback received from the TaskSetManagers. Note that it's important to initialize this with
* a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
- * to work, and these are required for locality in ClusterTaskSetManager.
+ * to work, and these are required for locality in TaskSetManager.
*/
class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
- extends ClusterScheduler(sc)
+ extends TaskSchedulerImpl(sc)
{
val startedTasks = new ArrayBuffer[Long]
val endedTasks = new mutable.HashMap[Long, TaskEndReason]
@@ -79,16 +78,19 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)
}
-class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
+class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL}
+
private val conf = new SparkConf
+
val LOCALITY_WAIT = conf.getOrElse("spark.locality.wait", "3000").toLong
+ val MAX_TASK_FAILURES = 4
test("TaskSet with no preferences") {
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
- val manager = new ClusterTaskSetManager(sched, taskSet)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
// Offer a host with no CPUs
assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None)
@@ -114,7 +116,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(3)
- val manager = new ClusterTaskSetManager(sched, taskSet)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
// First three offers should all find tasks
for (i <- 0 until 3) {
@@ -151,7 +153,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
Seq() // Last task has no locality prefs
)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// First offer host1, exec1: first task should be chosen
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -197,7 +199,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
Seq(TaskLocation("host2"))
)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// First offer host1: first task should be chosen
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -234,7 +236,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
Seq(TaskLocation("host3"))
)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// First offer host1: first task should be chosen
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -262,7 +264,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -279,17 +281,17 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
// after the last failure.
- (1 to manager.MAX_TASK_FAILURES).foreach { index =>
+ (1 to manager.maxTaskFailures).foreach { index =>
val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
assert(offerResult != None,
"Expect resource offer on iteration %s to return a task".format(index))
assert(offerResult.get.index === 0)
manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
- if (index < manager.MAX_TASK_FAILURES) {
+ if (index < MAX_TASK_FAILURES) {
assert(!sched.taskSetsFailed.contains(taskSet.id))
} else {
assert(sched.taskSetsFailed.contains(taskSet.id))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
deleted file mode 100644
index 1e676c1719..0000000000
--- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
+++ /dev/null
@@ -1,227 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.util.concurrent.Semaphore
-import java.util.concurrent.CountDownLatch
-
-import scala.collection.mutable.HashMap
-
-import org.scalatest.{BeforeAndAfterEach, FunSuite}
-
-import org.apache.spark._
-
-
-class Lock() {
- var finished = false
- def jobWait() = {
- synchronized {
- while(!finished) {
- this.wait()
- }
- }
- }
-
- def jobFinished() = {
- synchronized {
- finished = true
- this.notifyAll()
- }
- }
-}
-
-object TaskThreadInfo {
- val threadToLock = HashMap[Int, Lock]()
- val threadToRunning = HashMap[Int, Boolean]()
- val threadToStarted = HashMap[Int, CountDownLatch]()
-}
-
-/*
- * 1. each thread contains one job.
- * 2. each job contains one stage.
- * 3. each stage only contains one task.
- * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure
- * it will get cpu core resource, and will wait to finished after user manually
- * release "Lock" and then cluster will contain another free cpu cores.
- * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue,
- * thus it will be scheduled later when cluster has free cpu cores.
- */
-class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach {
-
- override def afterEach() {
- super.afterEach()
- System.clearProperty("spark.scheduler.mode")
- }
-
- def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
-
- TaskThreadInfo.threadToRunning(threadIndex) = false
- val nums = sc.parallelize(threadIndex to threadIndex, 1)
- TaskThreadInfo.threadToLock(threadIndex) = new Lock()
- TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
- new Thread {
- if (poolName != null) {
- sc.setLocalProperty("spark.scheduler.pool", poolName)
- }
- override def run() {
- val ans = nums.map(number => {
- TaskThreadInfo.threadToRunning(number) = true
- TaskThreadInfo.threadToStarted(number).countDown()
- TaskThreadInfo.threadToLock(number).jobWait()
- TaskThreadInfo.threadToRunning(number) = false
- number
- }).collect()
- assert(ans.toList === List(threadIndex))
- sem.release()
- }
- }.start()
- }
-
- test("Local FIFO scheduler end-to-end test") {
- System.setProperty("spark.scheduler.mode", "FIFO")
- sc = new SparkContext("local[4]", "test")
- val sem = new Semaphore(0)
-
- createThread(1,null,sc,sem)
- TaskThreadInfo.threadToStarted(1).await()
- createThread(2,null,sc,sem)
- TaskThreadInfo.threadToStarted(2).await()
- createThread(3,null,sc,sem)
- TaskThreadInfo.threadToStarted(3).await()
- createThread(4,null,sc,sem)
- TaskThreadInfo.threadToStarted(4).await()
- // thread 5 and 6 (stage pending)must meet following two points
- // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager
- // queue before executing TaskThreadInfo.threadToLock(1).jobFinished()
- // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6
- // So I just use "sleep" 1s here for each thread.
- // TODO: any better solution?
- createThread(5,null,sc,sem)
- Thread.sleep(1000)
- createThread(6,null,sc,sem)
- Thread.sleep(1000)
-
- assert(TaskThreadInfo.threadToRunning(1) === true)
- assert(TaskThreadInfo.threadToRunning(2) === true)
- assert(TaskThreadInfo.threadToRunning(3) === true)
- assert(TaskThreadInfo.threadToRunning(4) === true)
- assert(TaskThreadInfo.threadToRunning(5) === false)
- assert(TaskThreadInfo.threadToRunning(6) === false)
-
- TaskThreadInfo.threadToLock(1).jobFinished()
- TaskThreadInfo.threadToStarted(5).await()
-
- assert(TaskThreadInfo.threadToRunning(1) === false)
- assert(TaskThreadInfo.threadToRunning(2) === true)
- assert(TaskThreadInfo.threadToRunning(3) === true)
- assert(TaskThreadInfo.threadToRunning(4) === true)
- assert(TaskThreadInfo.threadToRunning(5) === true)
- assert(TaskThreadInfo.threadToRunning(6) === false)
-
- TaskThreadInfo.threadToLock(3).jobFinished()
- TaskThreadInfo.threadToStarted(6).await()
-
- assert(TaskThreadInfo.threadToRunning(1) === false)
- assert(TaskThreadInfo.threadToRunning(2) === true)
- assert(TaskThreadInfo.threadToRunning(3) === false)
- assert(TaskThreadInfo.threadToRunning(4) === true)
- assert(TaskThreadInfo.threadToRunning(5) === true)
- assert(TaskThreadInfo.threadToRunning(6) === true)
-
- TaskThreadInfo.threadToLock(2).jobFinished()
- TaskThreadInfo.threadToLock(4).jobFinished()
- TaskThreadInfo.threadToLock(5).jobFinished()
- TaskThreadInfo.threadToLock(6).jobFinished()
- sem.acquire(6)
- }
-
- test("Local fair scheduler end-to-end test") {
- System.setProperty("spark.scheduler.mode", "FAIR")
- val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
- System.setProperty("spark.scheduler.allocation.file", xmlPath)
-
- sc = new SparkContext("local[8]", "LocalSchedulerSuite")
- val sem = new Semaphore(0)
-
- createThread(10,"1",sc,sem)
- TaskThreadInfo.threadToStarted(10).await()
- createThread(20,"2",sc,sem)
- TaskThreadInfo.threadToStarted(20).await()
- createThread(30,"3",sc,sem)
- TaskThreadInfo.threadToStarted(30).await()
-
- assert(TaskThreadInfo.threadToRunning(10) === true)
- assert(TaskThreadInfo.threadToRunning(20) === true)
- assert(TaskThreadInfo.threadToRunning(30) === true)
-
- createThread(11,"1",sc,sem)
- TaskThreadInfo.threadToStarted(11).await()
- createThread(21,"2",sc,sem)
- TaskThreadInfo.threadToStarted(21).await()
- createThread(31,"3",sc,sem)
- TaskThreadInfo.threadToStarted(31).await()
-
- assert(TaskThreadInfo.threadToRunning(11) === true)
- assert(TaskThreadInfo.threadToRunning(21) === true)
- assert(TaskThreadInfo.threadToRunning(31) === true)
-
- createThread(12,"1",sc,sem)
- TaskThreadInfo.threadToStarted(12).await()
- createThread(22,"2",sc,sem)
- TaskThreadInfo.threadToStarted(22).await()
- createThread(32,"3",sc,sem)
-
- assert(TaskThreadInfo.threadToRunning(12) === true)
- assert(TaskThreadInfo.threadToRunning(22) === true)
- assert(TaskThreadInfo.threadToRunning(32) === false)
-
- TaskThreadInfo.threadToLock(10).jobFinished()
- TaskThreadInfo.threadToStarted(32).await()
-
- assert(TaskThreadInfo.threadToRunning(32) === true)
-
- //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager
- // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished.
- //2. priority of 23 and 33 will be meaningless as using fair scheduler here.
- createThread(23,"2",sc,sem)
- createThread(33,"3",sc,sem)
- Thread.sleep(1000)
-
- TaskThreadInfo.threadToLock(11).jobFinished()
- TaskThreadInfo.threadToStarted(23).await()
-
- assert(TaskThreadInfo.threadToRunning(23) === true)
- assert(TaskThreadInfo.threadToRunning(33) === false)
-
- TaskThreadInfo.threadToLock(12).jobFinished()
- TaskThreadInfo.threadToStarted(33).await()
-
- assert(TaskThreadInfo.threadToRunning(33) === true)
-
- TaskThreadInfo.threadToLock(20).jobFinished()
- TaskThreadInfo.threadToLock(21).jobFinished()
- TaskThreadInfo.threadToLock(22).jobFinished()
- TaskThreadInfo.threadToLock(23).jobFinished()
- TaskThreadInfo.threadToLock(30).jobFinished()
- TaskThreadInfo.threadToLock(31).jobFinished()
- TaskThreadInfo.threadToLock(32).jobFinished()
- TaskThreadInfo.threadToLock(33).jobFinished()
-
- sem.acquire(11)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
new file mode 100644
index 0000000000..67a57a0e7f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import org.scalatest.FunSuite
+import org.apache.spark.scheduler._
+import org.apache.spark.{LocalSparkContext, SparkContext, Success}
+import org.apache.spark.scheduler.SparkListenerTaskStart
+import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
+
+class JobProgressListenerSuite extends FunSuite with LocalSparkContext {
+ test("test executor id to summary") {
+ val sc = new SparkContext("local", "test")
+ val listener = new JobProgressListener(sc)
+ val taskMetrics = new TaskMetrics()
+ val shuffleReadMetrics = new ShuffleReadMetrics()
+
+ // nothing in it
+ assert(listener.stageIdToExecutorSummaries.size == 0)
+
+ // finish this task, should get updated shuffleRead
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail())
+ .shuffleRead == 1000)
+
+ // finish a task with unknown executor-id, nothing should happen
+ taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.size == 1)
+
+ // finish another task on the same executor; its shuffleRead should accumulate
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail())
+ .shuffleRead == 2000)
+
+ // finish a task on a second executor; it should get its own summary
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ taskInfo = new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail())
+ .shuffleRead == 1000)
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
index 9a8e4209ed..22994fb2ec 100644
--- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
@@ -53,7 +53,7 @@ public class JavaKafkaWordCount {
}
// Create the context with a 2 second batch size
- JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
+ JavaStreamingContext ssc = new JavaStreamingContext(args[0], "KafkaWordCount",
new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
int numThreads = Integer.parseInt(args[4]);
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
new file mode 100644
index 0000000000..8247c1ebc5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.classification._
+import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.recommendation._
+import org.apache.spark.rdd.RDD
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import java.nio.DoubleBuffer
+
+/**
+ * The Java stubs necessary for the Python mllib bindings.
+ */
+class PythonMLLibAPI extends Serializable {
+ private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
+ val packetLength = bytes.length
+ if (packetLength < 16) {
+ throw new IllegalArgumentException("Byte array too short.")
+ }
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ val magic = bb.getLong()
+ if (magic != 1) {
+ throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ }
+ val length = bb.getLong()
+ if (packetLength != 16 + 8 * length) {
+ throw new IllegalArgumentException("Length " + length + " is wrong.")
+ }
+ val db = bb.asDoubleBuffer()
+ val ans = new Array[Double](length.toInt)
+ db.get(ans)
+ return ans
+ }
+
+ private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
+ val len = doubles.length
+ val bytes = new Array[Byte](16 + 8 * len)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.putLong(1)
+ bb.putLong(len)
+ val db = bb.asDoubleBuffer()
+ db.put(doubles)
+ return bytes
+ }
+
+ private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
+ val packetLength = bytes.length
+ if (packetLength < 24) {
+ throw new IllegalArgumentException("Byte array too short.")
+ }
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ val magic = bb.getLong()
+ if (magic != 2) {
+ throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ }
+ val rows = bb.getLong()
+ val cols = bb.getLong()
+ if (packetLength != 24 + 8 * rows * cols) {
+ throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
+ }
+ val db = bb.asDoubleBuffer()
+ val ans = new Array[Array[Double]](rows.toInt)
+ for (i <- 0 until rows.toInt) {
+ ans(i) = new Array[Double](cols.toInt)
+ db.get(ans(i))
+ }
+ return ans
+ }
+
+ private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
+ val rows = doubles.length
+ var cols = 0
+ if (rows > 0) {
+ cols = doubles(0).length
+ }
+ val bytes = new Array[Byte](24 + 8 * rows * cols)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.putLong(2)
+ bb.putLong(rows)
+ bb.putLong(cols)
+ val db = bb.asDoubleBuffer()
+ for (i <- 0 until rows) {
+ db.put(doubles(i))
+ }
+ return bytes
+ }
+
+ private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
+ dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
+ java.util.LinkedList[java.lang.Object] = {
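+ // Each serialized data point packs the label in element 0, followed by the features.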
+ val data = dataBytesJRDD.rdd.map(xBytes => {
+ val x = deserializeDoubleVector(xBytes)
+ LabeledPoint(x(0), x.slice(1, x.length))
+ })
+ val initialWeights = deserializeDoubleVector(initialWeightsBA)
+ val model = trainFunc(data, initialWeights)
+ val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(serializeDoubleVector(model.weights))
+ ret.add(model.intercept: java.lang.Double)
+ return ret
+ }
+
+ /**
+ * Java stub for Python mllib LinearRegressionWithSGD.train()
+ */
+ def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
+ numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ return trainRegressionModel((data, initialWeights) =>
+ LinearRegressionWithSGD.train(data, numIterations, stepSize,
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA)
+ }
+
+ /**
+ * Java stub for Python mllib LassoWithSGD.train()
+ */
+ def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+ stepSize: Double, regParam: Double, miniBatchFraction: Double,
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ return trainRegressionModel((data, initialWeights) =>
+ LassoWithSGD.train(data, numIterations, stepSize, regParam,
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA)
+ }
+
+ /**
+ * Java stub for Python mllib RidgeRegressionWithSGD.train()
+ */
+ def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+ stepSize: Double, regParam: Double, miniBatchFraction: Double,
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ return trainRegressionModel((data, initialWeights) =>
+ RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA)
+ }
+
+ /**
+ * Java stub for Python mllib SVMWithSGD.train()
+ */
+ def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+ stepSize: Double, regParam: Double, miniBatchFraction: Double,
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ return trainRegressionModel((data, initialWeights) =>
+ SVMWithSGD.train(data, numIterations, stepSize, regParam,
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA)
+ }
+
+ /**
+ * Java stub for Python mllib LogisticRegressionWithSGD.train()
+ */
+ def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
+ numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ return trainRegressionModel((data, initialWeights) =>
+ LogisticRegressionWithSGD.train(data, numIterations, stepSize,
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA)
+ }
+
+ /**
+ * Java stub for Python mllib KMeans.train()
+ */
+ def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
+ maxIterations: Int, runs: Int, initializationMode: String):
+ java.util.List[java.lang.Object] = {
+ val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
+ val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
+ val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(serializeDoubleMatrix(model.clusterCenters))
+ return ret
+ }
+
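+ // A serialized rating is 16 bytes -- [int32 user][int32 product][float64 rating] --
+ // in machine byte order; see _serialize_rating on the Python side.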
+ private def unpackRating(ratingBytes: Array[Byte]): Rating = {
+ val bb = ByteBuffer.wrap(ratingBytes)
+ bb.order(ByteOrder.nativeOrder())
+ val user = bb.getInt()
+ val product = bb.getInt()
+ val rating = bb.getDouble()
+ return new Rating(user, product, rating)
+ }
+
+ /**
+ * Java stub for Python mllib ALS.train(). This stub returns a handle
+ * to the Java object instead of the content of the Java object. Extra care
+ * needs to be taken in the Python code to ensure it gets freed on exit; see
+ * the Py4J documentation.
+ */
+ def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
+ iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
+ val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+ return ALS.train(ratings, rank, iterations, lambda, blocks)
+ }
+
+ /**
+ * Java stub for Python mllib ALS.trainImplicit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ */
+ def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
+ iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
+ val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+ return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 963b5b88be..1bba6a5ae4 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -437,8 +437,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
}
def monitorApplication(appId: ApplicationId): Boolean = {
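+ // Interval (in ms) between requests for the application report, configurable
+ // via the spark.yarn.report.interval property.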
+ val interval = new SparkConf().getOrElse("spark.yarn.report.interval", "1000").toLong
+
while (true) {
- Thread.sleep(1000)
+ Thread.sleep(interval)
val report = super.getApplicationReport(appId)
logInfo("Application report from ASM: \n" +
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 71d1cbd416..abc3447746 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -27,8 +27,8 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import org.apache.spark.Logging
-import org.apache.spark.scheduler.SplitInfo
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.util.Utils
import org.apache.hadoop.conf.Configuration
@@ -233,9 +233,9 @@ private[yarn] class YarnAllocationHandler(
// Note that the list we create below tries to ensure that not all containers end up within
// a host if there is a sufficiently large number of hosts/containers.
val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size)
- allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
- allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
- allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+ allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers)
+ allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers)
+ allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers)
// Run each of the allocated containers.
for (container <- allocatedContainersToProcess) {
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
index 63a0449e5a..522e0a9ad7 100644
--- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -20,13 +20,14 @@ package org.apache.spark.scheduler.cluster
import org.apache.spark._
import org.apache.hadoop.conf.Configuration
import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils
/**
*
* This scheduler launches workers through YARN by calling into Client to launch WorkerLauncher as the AM.
*/
-private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) {
def this(sc: SparkContext) = this(sc, new Configuration())
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 6feaaff014..4b69f5078b 100644
--- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -20,9 +20,10 @@ package org.apache.spark.scheduler.cluster
import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
import org.apache.spark.{SparkException, Logging, SparkContext}
import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+import org.apache.spark.scheduler.TaskSchedulerImpl
private[spark] class YarnClientSchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with Logging {
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 29b3f22e13..a4638cc863 100644
--- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster
import org.apache.spark._
import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils
import org.apache.hadoop.conf.Configuration
@@ -26,7 +27,7 @@ import org.apache.hadoop.conf.Configuration
*
* This is a simple extension to TaskSchedulerImpl that ensures the appropriate initialization of ApplicationMaster, etc. is done
*/
-private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) {
logInfo("Created YarnClusterScheduler")
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index ffb54a24ac..37d6f1b60d 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -114,6 +114,9 @@ object SparkBuild extends Build {
fork := true,
javaOptions += "-Xmx3g",
+ // Show full stack trace and duration in test cases.
+ testOptions in Test += Tests.Argument("-oDF"),
+
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
@@ -260,7 +263,7 @@ object SparkBuild extends Build {
libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v )
)
-
+
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
libraryDependencies ++= Seq(
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 128f078d12..d8ca9fce00 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -63,5 +63,6 @@ def launch_gateway():
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
+ java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
new file mode 100644
index 0000000000..b1a5df109b
--- /dev/null
+++ b/python/pyspark/mllib/__init__.py
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python bindings for MLlib.
+"""
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
new file mode 100644
index 0000000000..e74ba0fabc
--- /dev/null
+++ b/python/pyspark/mllib/_common.py
@@ -0,0 +1,227 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
+from pyspark import SparkContext
+from pyspark.rdd import RDD
+
+# Double vector format:
+#
+# [8-byte 1] [8-byte length] [length*8 bytes of data]
+#
+# Double matrix format:
+#
+# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
+#
+# This is all in machine-endian. That means that the Java interpreter and the
+# Python interpreter must agree on what endian the machine is.
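+#
+# For example, the vector [1.0, 2.0] serializes to 32 bytes: the int64 magic 1,
+# the int64 length 2, then the two float64 values, so
+# len(_serialize_double_vector(array([1.0, 2.0]))) == 32.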
+
+def _deserialize_byte_array(shape, ba, offset):
+ """Wrapper around ndarray aliasing hack.
+
+ >>> x = array([1.0, 2.0, 3.0, 4.0, 5.0])
+ >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
+ True
+ >>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2)
+ >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
+ True
+ """
+ ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64",
+ order='C')
+ return ar.copy()
+
+def _serialize_double_vector(v):
+ """Serialize a double vector into a mutually understood format."""
+ if type(v) != ndarray:
+ raise TypeError("_serialize_double_vector called on a %s; "
+ "wanted ndarray" % type(v))
+ if v.dtype != float64:
+ raise TypeError("_serialize_double_vector called on an ndarray of %s; "
+ "wanted ndarray of float64" % v.dtype)
+ if v.ndim != 1:
+ raise TypeError("_serialize_double_vector called on a %ddarray; "
+ "wanted a 1darray" % v.ndim)
+ length = v.shape[0]
+ ba = bytearray(16 + 8*length)
+ header = ndarray(shape=[2], buffer=ba, dtype="int64")
+ header[0] = 1
+ header[1] = length
+ copyto(ndarray(shape=[length], buffer=ba, offset=16,
+ dtype="float64"), v)
+ return ba
+
+def _deserialize_double_vector(ba):
+ """Deserialize a double vector from a mutually understood format.
+
+ >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0])
+ >>> array_equal(x, _deserialize_double_vector(_serialize_double_vector(x)))
+ True
+ """
+ if type(ba) != bytearray:
+ raise TypeError("_deserialize_double_vector called on a %s; "
+ "wanted bytearray" % type(ba))
+ if len(ba) < 16:
+ raise TypeError("_deserialize_double_vector called on a %d-byte array, "
+ "which is too short" % len(ba))
+ if (len(ba) & 7) != 0:
+ raise TypeError("_deserialize_double_vector called on a %d-byte array, "
+ "which is not a multiple of 8" % len(ba))
+ header = ndarray(shape=[2], buffer=ba, dtype="int64")
+ if header[0] != 1:
+ raise TypeError("_deserialize_double_vector called on bytearray "
+ "with wrong magic")
+ length = header[1]
+ if len(ba) != 8*length + 16:
+ raise TypeError("_deserialize_double_vector called on bytearray "
+ "with wrong length")
+ return _deserialize_byte_array([length], ba, 16)
+
+def _serialize_double_matrix(m):
+ """Serialize a double matrix into a mutually understood format."""
+ if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2):
+ rows = m.shape[0]
+ cols = m.shape[1]
+ ba = bytearray(24 + 8 * rows * cols)
+ header = ndarray(shape=[3], buffer=ba, dtype="int64")
+ header[0] = 2
+ header[1] = rows
+ header[2] = cols
+ copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24,
+ dtype="float64", order='C'), m)
+ return ba
+ else:
+ raise TypeError("_serialize_double_matrix called on a "
+ "non-double-matrix")
+
+def _deserialize_double_matrix(ba):
+ """Deserialize a double matrix from a mutually understood format."""
+ if type(ba) != bytearray:
+ raise TypeError("_deserialize_double_matrix called on a %s; "
+ "wanted bytearray" % type(ba))
+ if len(ba) < 24:
+ raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
+ "which is too short" % len(ba))
+ if (len(ba) & 7) != 0:
+ raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
+ "which is not a multiple of 8" % len(ba))
+ header = ndarray(shape=[3], buffer=ba, dtype="int64")
+ if (header[0] != 2):
+ raise TypeError("_deserialize_double_matrix called on bytearray "
+ "with wrong magic")
+ rows = header[1]
+ cols = header[2]
+ if (len(ba) != 8*rows*cols + 24):
+ raise TypeError("_deserialize_double_matrix called on bytearray "
+ "with wrong length")
+ return _deserialize_byte_array([rows, cols], ba, 24)
+
+def _linear_predictor_typecheck(x, coeffs):
+ """Check that x is a one-dimensional vector of the right shape.
+ This is a temporary workaround until I actually implement bulk predict."""
+ if type(x) == ndarray:
+ if x.ndim == 1:
+ if x.shape == coeffs.shape:
+ pass
+ else:
+ raise RuntimeError("Got array of %d elements; wanted %d"
+ % (shape(x)[0], shape(coeffs)[0]))
+ else:
+ raise RuntimeError("Bulk predict not yet supported.")
+ elif (type(x) == RDD):
+ raise RuntimeError("Bulk predict not yet supported.")
+ else:
+ raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
+
+def _get_unmangled_rdd(data, serializer):
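+ """Run each element of a Python RDD through `serializer` to get a raw
+ bytearray, and bypass the usual Python pickling so the bytes reach the
+ JVM untouched. The result is cached, since training typically makes
+ multiple passes over the data."""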
+ dataBytes = data.map(serializer)
+ dataBytes._bypass_serializer = True
+ dataBytes.cache()
+ return dataBytes
+
+# Map a pickled Python RDD of numpy double vectors to a Java RDD of
+# _serialized_double_vectors
+def _get_unmangled_double_vector_rdd(data):
+ return _get_unmangled_rdd(data, _serialize_double_vector)
+
+class LinearModel(object):
+ """Something that has a vector of coefficients and an intercept."""
+ def __init__(self, coeff, intercept):
+ self._coeff = coeff
+ self._intercept = intercept
+
+class LinearRegressionModelBase(LinearModel):
+ """A linear regression model.
+
+ >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1)
+ >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6
+ True
+ """
+ def predict(self, x):
+ """Predict the value of the dependent variable given a vector x"""
+ """containing values for the independent variables."""
+ _linear_predictor_typecheck(x, self._coeff)
+ return dot(self._coeff, x) + self._intercept
+
+# If we weren't given initial weights, default to a vector of ones with one
+# entry per feature (each data point's first element is its label, hence the
+# length of shape[0] - 1).
+def _get_initial_weights(initial_weights, data):
+ if initial_weights is None:
+ initial_weights = data.first()
+ if type(initial_weights) != ndarray:
+ raise TypeError("At least one data element has type "
+ + type(initial_weights).__name__ + " which is not ndarray")
+ if initial_weights.ndim != 1:
+ raise TypeError("At least one data element has "
+ + str(initial_weights.ndim) + " dimensions, which is not 1")
+ initial_weights = ones([initial_weights.shape[0] - 1])
+ return initial_weights
+
+# train_func should take two parameters, namely data and initial_weights, and
+# return the result of a call to the appropriate JVM stub.
+# _regression_train_wrapper is responsible for setup and error checking.
+def _regression_train_wrapper(sc, train_func, klass, data, initial_weights):
+ initial_weights = _get_initial_weights(initial_weights, data)
+ dataBytes = _get_unmangled_double_vector_rdd(data)
+ ans = train_func(dataBytes, _serialize_double_vector(initial_weights))
+ if len(ans) != 2:
+ raise RuntimeError("JVM call result had unexpected length")
+ elif type(ans[0]) != bytearray:
+ raise RuntimeError("JVM call result had first element of type "
+ + type(ans[0]).__name__ + " which is not bytearray")
+ elif type(ans[1]) != float:
+ raise RuntimeError("JVM call result had second element of type "
+ + type(ans[0]).__name__ + " which is not float")
+ return klass(_deserialize_double_vector(ans[0]), ans[1])
+
+def _serialize_rating(r):
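+ """Pack a (user, product, rating) tuple into 16 bytes: two int32s followed
+ by a float64, in machine byte order (the layout unpackRating expects on
+ the JVM side)."""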
+ ba = bytearray(16)
+ intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
+ doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8)
+ intpart[0], intpart[1], doublepart[0] = r
+ return ba
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
new file mode 100644
index 0000000000..70de332d34
--- /dev/null
+++ b/python/pyspark/mllib/classification.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot, shape
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+ _serialize_double_matrix, _deserialize_double_matrix, \
+ _serialize_double_vector, _deserialize_double_vector, \
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+ LinearModel, _linear_predictor_typecheck
+from math import exp, log
+
+class LogisticRegressionModel(LinearModel):
+ """A linear binary classification model derived from logistic regression.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
+ >>> lrm = LogisticRegressionWithSGD.train(sc, sc.parallelize(data))
+ >>> lrm.predict(array([1.0])) is not None
+ True
+ """
+ def predict(self, x):
+ _linear_predictor_typecheck(x, self._coeff)
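+ # Logistic (sigmoid) probability of the positive class, thresholded at 0.5.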
+ margin = dot(x, self._coeff) + self._intercept
+ prob = 1/(1 + exp(-margin))
+ return 1 if prob > 0.5 else 0
+
+class LogisticRegressionWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a logistic regression model on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(d._jrdd,
+ iterations, step, mini_batch_fraction, i),
+ LogisticRegressionModel, data, initial_weights)
+
+class SVMModel(LinearModel):
+ """A support vector machine.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
+ >>> svm = SVMWithSGD.train(sc, sc.parallelize(data))
+ >>> svm.predict(array([1.0])) is not None
+ True
+ """
+ def predict(self, x):
+ _linear_predictor_typecheck(x, self._coeff)
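+ # Classify by the sign of the SVM margin, w.x + b.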
+ margin = dot(x, self._coeff) + self._intercept
+ return 1 if margin >= 0 else 0
+
+class SVMWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a support vector machine on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(d._jrdd,
+ iterations, step, reg_param, mini_batch_fraction, i),
+ SVMModel, data, initial_weights)
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
new file mode 100644
index 0000000000..8cf20e591a
--- /dev/null
+++ b/python/pyspark/mllib/clustering.py
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot
+from math import sqrt
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+ _serialize_double_matrix, _deserialize_double_matrix, \
+ _serialize_double_vector, _deserialize_double_vector, \
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper
+
+class KMeansModel(object):
+ """A clustering model derived from the k-means method.
+
+ >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
+ >>> clusters = KMeans.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random")
+ >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0]))
+ True
+ >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0]))
+ True
+ >>> clusters = KMeans.train(sc, sc.parallelize(data), 2)
+ """
+ def __init__(self, centers_):
+ self.centers = centers_
+
+ def predict(self, x):
+ """Find the cluster to which x belongs in this model."""
+ best = 0
+ best_distance = 1e75
+ for i in range(0, self.centers.shape[0]):
+ diff = x - self.centers[i]
+ distance = sqrt(dot(diff, diff))
+ if distance < best_distance:
+ best = i
+ best_distance = distance
+ return best
+
+class KMeans(object):
+ @classmethod
+ def train(cls, sc, data, k, maxIterations=100, runs=1,
+ initialization_mode="k-means||"):
+ """Train a k-means clustering model."""
+ dataBytes = _get_unmangled_double_vector_rdd(data)
+ ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd,
+ k, maxIterations, runs, initialization_mode)
+ if len(ans) != 1:
+ raise RuntimeError("JVM call result had unexpected length")
+ elif type(ans[0]) != bytearray:
+ raise RuntimeError("JVM call result had first element of type "
+ + type(ans[0]) + " which is not bytearray")
+ return KMeansModel(_deserialize_double_matrix(ans[0]))
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
new file mode 100644
index 0000000000..14d06cba21
--- /dev/null
+++ b/python/pyspark/mllib/recommendation.py
@@ -0,0 +1,74 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+ _serialize_double_matrix, _deserialize_double_matrix, \
+ _serialize_double_vector, _deserialize_double_vector, \
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper
+
+class MatrixFactorizationModel(object):
+ """A matrix factorisation model trained by regularized alternating
+ least-squares.
+
+ >>> r1 = (1, 1, 1.0)
+ >>> r2 = (1, 2, 2.0)
+ >>> r3 = (2, 1, 2.0)
+ >>> ratings = sc.parallelize([r1, r2, r3])
+ >>> model = ALS.trainImplicit(sc, ratings, 1)
+ >>> model.predict(2,2) is not None
+ True
+ """
+
+ def __init__(self, sc, java_model):
+ self._context = sc
+ self._java_model = java_model
+
+ def __del__(self):
+ self._context._gateway.detach(self._java_model)
+
+ def predict(self, user, product):
+ return self._java_model.predict(user, product)
+
+class ALS(object):
+ @classmethod
+ def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+ ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+ mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd,
+ rank, iterations, lambda_, blocks)
+ return MatrixFactorizationModel(sc, mod)
+
+ @classmethod
+ def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+ ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+ mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd,
+ rank, iterations, lambda_, blocks, alpha)
+ return MatrixFactorizationModel(sc, mod)
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
new file mode 100644
index 0000000000..a3a68b29e0
--- /dev/null
+++ b/python/pyspark/mllib/regression.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+ _serialize_double_matrix, _deserialize_double_matrix, \
+ _serialize_double_vector, _deserialize_double_vector, \
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+ _linear_predictor_typecheck
+
+class LinearModel(object):
+ """Something that has a vector of coefficients and an intercept."""
+ def __init__(self, coeff, intercept):
+ self._coeff = coeff
+ self._intercept = intercept
+
+class LinearRegressionModelBase(LinearModel):
+ """A linear regression model.
+
+ >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1)
+ >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6
+ True
+ """
+ def predict(self, x):
+ """Predict the value of the dependent variable given a vector x"""
+ """containing values for the independent variables."""
+ _linear_predictor_typecheck(x, self._coeff)
+ return dot(self._coeff, x) + self._intercept
+
+class LinearRegressionModel(LinearRegressionModelBase):
+ """A linear regression model derived from a least-squares fit.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+ >>> lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ """
+
+class LinearRegressionWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a linear regression model on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
+ d._jrdd, iterations, step, mini_batch_fraction, i),
+ LinearRegressionModel, data, initial_weights)
+
+class LassoModel(LinearRegressionModelBase):
+ """A linear regression model derived from a least-squares fit with an
+ l_1 penalty term.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+ >>> lrm = LassoWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ """
+
+class LassoWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a Lasso regression model on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd,
+ iterations, step, reg_param, mini_batch_fraction, i),
+ LassoModel, data, initial_weights)
+
+class RidgeRegressionModel(LinearRegressionModelBase):
+ """A linear regression model derived from a least-squares fit with an
+ l_2 penalty term.
+
+ >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+ >>> lrm = RidgeRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ """
+
+class RidgeRegressionWithSGD(object):
+ @classmethod
+ def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ mini_batch_fraction=1.0, initial_weights=None):
+ """Train a ridge regression model on the given data."""
+ return _regression_train_wrapper(sc, lambda d, i:
+ sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd,
+ iterations, step, reg_param, mini_batch_fraction, i),
+ RidgeRegressionModel, data, initial_weights)
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs,
+ optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 811fa6f018..2a500ab919 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -308,4 +308,4 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
- stream.write(obj) \ No newline at end of file
+ stream.write(obj)
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index a475959090..ef07eb437b 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -42,7 +42,7 @@ print "Using Python version %s (%s, %s)" % (
platform.python_version(),
platform.python_build()[0],
platform.python_build()[1])
-print "Spark context avaiable as sc."
+print "Spark context available as sc."
if add_files != None:
print "Adding files: [%s]" % ", ".join(add_files)
diff --git a/spark-class b/spark-class
index 4eb95a9ba2..802e4aa104 100755
--- a/spark-class
+++ b/spark-class
@@ -129,11 +129,11 @@ fi
# Compute classpath using external script
CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
-CLASSPATH="$SPARK_TOOLS_JAR:$CLASSPATH"
+CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR"
if $cygwin; then
- CLASSPATH=`cygpath -wp $CLASSPATH`
- export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR`
+ CLASSPATH=`cygpath -wp $CLASSPATH`
+ export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR`
fi
export CLASSPATH
diff --git a/spark-class2.cmd b/spark-class2.cmd
index 3869d0761b..dc9dadf356 100644
--- a/spark-class2.cmd
+++ b/spark-class2.cmd
@@ -17,7 +17,7 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-set SCALA_VERSION=2.9.3
+set SCALA_VERSION=2.10
rem Figure out where the Spark framework is installed
set FWDIR=%~dp0
@@ -75,7 +75,7 @@ rem Compute classpath using external script
set DONT_PRINT_CLASSPATH=1
call "%FWDIR%bin\compute-classpath.cmd"
set DONT_PRINT_CLASSPATH=0
-set CLASSPATH=%SPARK_TOOLS_JAR%;%CLASSPATH%
+set CLASSPATH=%CLASSPATH%;%SPARK_TOOLS_JAR%
rem Figure out where java is.
set RUNNER=java
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index f106bba678..35e23c1355 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -39,9 +39,9 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val graph = ssc.graph
val checkpointDir = ssc.checkpointDir
val checkpointDuration = ssc.checkpointDuration
- val pendingTimes = ssc.scheduler.jobManager.getPendingTimes()
+ val pendingTimes = ssc.scheduler.getPendingTimes()
val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf)
- val sparkConf = ssc.sc.conf
+ val sparkConf = ssc.conf
def validate() {
assert(master != null, "Checkpoint.master is null")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 8005202500..ce2a9d4142 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -17,24 +17,19 @@
package org.apache.spark.streaming
-import org.apache.spark.streaming.dstream._
import StreamingContext._
-import org.apache.spark.util.MetadataCleaner
-
-//import Time._
-
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.scheduler.Job
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.MetadataCleaner
-import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.conf.Configuration
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index b9a58fded6..daed7ff7c3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -21,6 +21,7 @@ import dstream.InputDStream
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
import org.apache.spark.Logging
+import org.apache.spark.streaming.scheduler.Job
final private[streaming] class DStreamGraph extends Serializable with Logging {
initLogging()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala
deleted file mode 100644
index 5233129506..0000000000
--- a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming
-
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
-import java.util.concurrent.Executors
-import collection.mutable.HashMap
-import collection.mutable.ArrayBuffer
-
-
-private[streaming]
-class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {
-
- class JobHandler(ssc: StreamingContext, job: Job) extends Runnable {
- def run() {
- SparkEnv.set(ssc.env)
- try {
- val timeTaken = job.run()
- logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format(
- (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0))
- } catch {
- case e: Exception =>
- logError("Running " + job + " failed", e)
- }
- clearJob(job)
- }
- }
-
- initLogging()
-
- val jobExecutor = Executors.newFixedThreadPool(numThreads)
- val jobs = new HashMap[Time, ArrayBuffer[Job]]
-
- def runJob(job: Job) {
- jobs.synchronized {
- jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job
- }
- jobExecutor.execute(new JobHandler(ssc, job))
- logInfo("Added " + job + " to queue")
- }
-
- def stop() {
- jobExecutor.shutdown()
- }
-
- private def clearJob(job: Job) {
- var timeCleared = false
- val time = job.time
- jobs.synchronized {
- val jobsOfTime = jobs.get(time)
- if (jobsOfTime.isDefined) {
- jobsOfTime.get -= job
- if (jobsOfTime.get.isEmpty) {
- jobs -= time
- timeCleared = true
- }
- } else {
- throw new Exception("Job finished for time " + job.time +
- " but time does not exist in jobs")
- }
- }
- if (timeCleared) {
- ssc.scheduler.clearOldMetadata(time)
- }
- }
-
- def getPendingTimes(): Array[Time] = {
- jobs.synchronized {
- jobs.keySet.toArray
- }
- }
-}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 286ec285a9..339f6e64a2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -47,9 +47,9 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.hadoop.fs.Path
import twitter4j.Status
import twitter4j.auth.Authorization
+import org.apache.spark.streaming.scheduler._
import akka.util.ByteString
-
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
* information (such as, cluster URL and job name) to internally create a SparkContext, it provides
@@ -160,9 +160,10 @@ class StreamingContext private (
}
}
- protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null
- protected[streaming] var receiverJobThread: Thread = null
- protected[streaming] var scheduler: Scheduler = null
+ protected[streaming] val checkpointDuration: Duration = {
+ if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration
+ }
+ protected[streaming] val scheduler = new JobScheduler(this)
/**
* Return the associated Spark context
@@ -524,6 +525,13 @@ class StreamingContext private (
graph.addOutputStream(outputStream)
}
+ /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
+ * receiving system events related to streaming.
+ */
+ def addStreamingListener(streamingListener: StreamingListener) {
+ scheduler.listenerBus.addListener(streamingListener)
+ }
+
protected def validate() {
assert(graph != null, "Graph is null")
graph.validate()
@@ -539,27 +547,22 @@ class StreamingContext private (
* Start the execution of the streams.
*/
def start() {
- if (checkpointDir != null && checkpointDuration == null && graph != null) {
- checkpointDuration = graph.batchDuration
- }
-
validate()
+ // Get the network input streams
val networkInputStreams = graph.getInputStreams().filter(s => s match {
case n: NetworkInputDStream[_] => true
case _ => false
}).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
+ // Start the network input tracker (must start before receivers)
if (networkInputStreams.length > 0) {
- // Start the network input tracker (must start before receivers)
networkInputTracker = new NetworkInputTracker(this, networkInputStreams)
networkInputTracker.start()
}
-
Thread.sleep(1000)
// Start the scheduler
- scheduler = new Scheduler(this)
scheduler.start()
}
@@ -570,7 +573,6 @@ class StreamingContext private (
try {
if (scheduler != null) scheduler.stop()
if (networkInputTracker != null) networkInputTracker.stop()
- if (receiverJobThread != null) receiverJobThread.interrupt()
sc.stop()
logInfo("StreamingContext stopped successfully")
} catch {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 5842a7cd68..29f673d8ae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -40,6 +40,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaRDD}
import org.apache.spark.streaming._
import org.apache.spark.streaming.dstream._
import org.apache.spark.SparkConf
+import org.apache.spark.streaming.scheduler.StreamingListener
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -696,6 +697,13 @@ class JavaStreamingContext(val ssc: StreamingContext) {
ssc.remember(duration)
}
+ /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
+ * receiving system events related to streaming.
+ */
+ def addStreamingListener(streamingListener: StreamingListener) {
+ ssc.addStreamingListener(streamingListener)
+ }
+
/**
* Starts the execution of the streams.
*/
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
index 98b14cb224..364abcde68 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
@@ -18,7 +18,8 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Duration, DStream, Job, Time}
+import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.scheduler.Job
import scala.reflect.ClassTag
private[streaming]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index bd607f9d18..1839ca3578 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -33,6 +33,7 @@ import org.apache.spark.streaming._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD}
import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
+import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver}
/**
* Abstract class for defining any InputDStream that has to start a receiver on worker
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
new file mode 100644
index 0000000000..4e8d07fe92
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.streaming.Time
+
+/**
+ * Class holding information about a completed batch.
+ * @param batchTime Time of the batch
+ * @param submissionTime Clock time of when jobs of this batch were submitted to
+ * the streaming scheduler queue
+ * @param processingStartTime Clock time of when the first job of this batch started processing
+ * @param processingEndTime Clock time of when the last job of this batch finished processing
+ */
+case class BatchInfo(
+ batchTime: Time,
+ submissionTime: Long,
+ processingStartTime: Option[Long],
+ processingEndTime: Option[Long]
+ ) {
+
+ /**
+ * Time taken for the first job of this batch to start processing from the time this batch
+ * was submitted to the streaming scheduler. Essentially, it is
+ * `processingStartTime` - `submissionTime`.
+ */
+ def schedulingDelay = processingStartTime.map(_ - submissionTime)
+
+ /**
+ * Time taken for all the jobs of this batch to finish processing from the time they started
+ * processing. Essentially, it is `processingEndTime` - `processingStartTime`.
+ */
+ def processingDelay = processingEndTime.zip(processingStartTime).map(x => x._1 - x._2).headOption
+
+ /**
+ * Time taken for all the jobs of this batch to finish processing from the time they
+ * were submitted. Essentially, it is `processingDelay` + `schedulingDelay`.
+ */
+ def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption
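+
+ // For example, a batch submitted at clock time 1000 ms whose first job starts
+ // at 1200 ms and whose last job finishes at 1700 ms has
+ // schedulingDelay = Some(200), processingDelay = Some(500) and totalDelay = Some(700).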
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
index 2128b7c7a6..7341bfbc99 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
@@ -15,13 +15,17 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
-import java.util.concurrent.atomic.AtomicLong
+import org.apache.spark.streaming.Time
+/**
+ * Class representing a Spark computation. It may contain multiple Spark jobs.
+ */
private[streaming]
class Job(val time: Time, func: () => _) {
- val id = Job.getNewId()
+ var id: String = _
+
def run(): Long = {
val startTime = System.currentTimeMillis
func()
@@ -29,13 +33,9 @@ class Job(val time: Time, func: () => _) {
(stopTime - startTime)
}
- override def toString = "streaming job " + id + " @ " + time
-}
-
-private[streaming]
-object Job {
- val id = new AtomicLong(0)
-
- def getNewId() = id.getAndIncrement()
-}
+ def setId(number: Int) {
+ id = "streaming job " + time + "." + number
+ }
+ override def toString = id
+} \ No newline at end of file
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 82ed6bed69..dbd08415a1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -15,31 +15,35 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
-import util.{ManualClock, RecurringTimer, Clock}
import org.apache.spark.SparkEnv
import org.apache.spark.Logging
+import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
+import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
+/**
+ * This class generates jobs from DStreams, and also drives the checkpointing
+ * and clean-up of DStream metadata.
+ */
private[streaming]
-class Scheduler(ssc: StreamingContext) extends Logging {
+class JobGenerator(jobScheduler: JobScheduler) extends Logging {
initLogging()
- val concurrentJobs = ssc.sc.conf.getOrElse("spark.streaming.concurrentJobs", "1").toInt
- val jobManager = new JobManager(ssc, concurrentJobs)
- val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
- new CheckpointWriter(ssc.conf, ssc.checkpointDir)
- } else {
- null
- }
-
+ val ssc = jobScheduler.ssc
val clockClass = ssc.sc.conf.getOrElse(
"spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
longTime => generateJobs(new Time(longTime)))
val graph = ssc.graph
+ lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
+ new CheckpointWriter(ssc.conf, ssc.checkpointDir)
+ } else {
+ null
+ }
+
var latestTime: Time = null
def start() = synchronized {
@@ -48,26 +52,24 @@ class Scheduler(ssc: StreamingContext) extends Logging {
} else {
startFirstTime()
}
- logInfo("Scheduler started")
+ logInfo("JobGenerator started")
}
def stop() = synchronized {
timer.stop()
- jobManager.stop()
if (checkpointWriter != null) checkpointWriter.stop()
ssc.graph.stop()
- logInfo("Scheduler stopped")
+ logInfo("JobGenerator stopped")
}
private def startFirstTime() {
val startTime = new Time(timer.getStartTime())
graph.start(startTime - graph.batchDuration)
timer.start(startTime.milliseconds)
- logInfo("Scheduler's timer started at " + startTime)
+ logInfo("JobGenerator's timer started at " + startTime)
}
private def restart() {
-
// If manual clock is being used for testing, then
// either set the manual clock to the last checkpointed time,
// or if the property is defined set it to that time
@@ -93,35 +95,34 @@ class Scheduler(ssc: StreamingContext) extends Logging {
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
logInfo("Batches to reschedule: " + timesToReschedule.mkString(", "))
timesToReschedule.foreach(time =>
- graph.generateJobs(time).foreach(jobManager.runJob)
+ jobScheduler.runJobs(time, graph.generateJobs(time))
)
// Restart the timer
timer.start(restartTime.milliseconds)
- logInfo("Scheduler's timer restarted at " + restartTime)
+ logInfo("JobGenerator's timer restarted at " + restartTime)
}
/** Generate jobs and perform checkpoint for the given `time`. */
- def generateJobs(time: Time) {
+ private def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
logInfo("\n-----------------------------------------------------\n")
- graph.generateJobs(time).foreach(jobManager.runJob)
+ jobScheduler.runJobs(time, graph.generateJobs(time))
latestTime = time
doCheckpoint(time)
}
/**
- * Clear old metadata assuming jobs of `time` have finished processing.
- * And also perform checkpoint.
+ * On batch completion, clear old metadata and checkpoint computation.
*/
- def clearOldMetadata(time: Time) {
+ private[streaming] def onBatchCompletion(time: Time) {
ssc.graph.clearOldMetadata(time)
doCheckpoint(time)
}
/** Perform checkpoint for the given `time`. */
- def doCheckpoint(time: Time) = synchronized {
- if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
+ private def doCheckpoint(time: Time) = synchronized {
+ if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
logInfo("Checkpointing graph for time " + time)
ssc.graph.updateCheckpointData(time)
checkpointWriter.write(new Checkpoint(ssc, time))
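
The timer wiring above is the heart of JobGenerator. A minimal, JDK-only sketch of a recurring timer in the same spirit (SimpleRecurringTimer is an illustrative stand-in, not the RecurringTimer from streaming.util):

    // Fires `callback` once per `period` milliseconds with the batch-aligned
    // trigger time, the way generateJobs(new Time(longTime)) is wired above.
    class SimpleRecurringTimer(period: Long, callback: Long => Unit) {
      @volatile private var stopped = false

      private val thread = new Thread("SimpleRecurringTimer") {
        setDaemon(true)
        override def run() {
          // Align the first trigger to the next multiple of `period`
          var nextTime = (System.currentTimeMillis / period + 1) * period
          while (!stopped) {
            val sleepTime = nextTime - System.currentTimeMillis
            if (sleepTime > 0) Thread.sleep(sleepTime)
            callback(nextTime)
            nextTime += period
          }
        }
      }

      def start() { thread.start() }
      def stop() { stopped = true }
    }

    // Usage: fire once per second
    // new SimpleRecurringTimer(1000, t => println("generate jobs for batch time " + t)).start()
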
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
new file mode 100644
index 0000000000..9511ccfbed
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkEnv
+import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors}
+import scala.collection.mutable.HashSet
+import org.apache.spark.streaming._
+
+/**
+ * This class schedules jobs to be run on Spark. It uses the JobGenerator to generate
+ * the jobs and runs them using a thread pool; the number of threads is controlled by
+ * the spark.streaming.concurrentJobs property (default: 1).
+ */
+private[streaming]
+class JobScheduler(val ssc: StreamingContext) extends Logging {
+
+ initLogging()
+
+ val jobSets = new ConcurrentHashMap[Time, JobSet]
+ val numConcurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
+ val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+ val generator = new JobGenerator(this)
+ val listenerBus = new StreamingListenerBus()
+
+ def clock = generator.clock
+
+ def start() {
+ generator.start()
+ }
+
+ def stop() {
+ generator.stop()
+ executor.shutdown()
+ if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
+ executor.shutdownNow()
+ }
+ }
+
+ def runJobs(time: Time, jobs: Seq[Job]) {
+ if (jobs.isEmpty) {
+ logInfo("No jobs added for time " + time)
+ } else {
+ val jobSet = new JobSet(time, jobs)
+ jobSets.put(time, jobSet)
+ jobSet.jobs.foreach(job => executor.execute(new JobHandler(job)))
+ logInfo("Added jobs for time " + time)
+ }
+ }
+
+ def getPendingTimes(): Array[Time] = {
+ jobSets.keySet.toArray(new Array[Time](0))
+ }
+
+ private def beforeJobStart(job: Job) {
+ val jobSet = jobSets.get(job.time)
+ if (!jobSet.hasStarted) {
+ listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo()))
+ }
+ jobSet.beforeJobStart(job)
+ logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
+ SparkEnv.set(generator.ssc.env)
+ }
+
+ private def afterJobEnd(job: Job) {
+ val jobSet = jobSets.get(job.time)
+ jobSet.afterJobStop(job)
+ logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
+ if (jobSet.hasCompleted) {
+ jobSets.remove(jobSet.time)
+ generator.onBatchCompletion(jobSet.time)
+ logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
+ jobSet.totalDelay / 1000.0, jobSet.time.toString,
+ jobSet.processingDelay / 1000.0
+ ))
+ listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo()))
+ }
+ }
+
+ private[streaming]
+ class JobHandler(job: Job) extends Runnable {
+ def run() {
+ beforeJobStart(job)
+ try {
+ job.run()
+ } catch {
+ case e: Exception =>
+ logError("Running " + job + " failed", e)
+ }
+ afterJobEnd(job)
+ }
+ }
+}
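
JobScheduler is, at its core, a fixed-size thread pool fed one Runnable per job. A stripped-down, runnable sketch of that pattern under the same property and shutdown conventions (JobPoolSketch and the printed jobs are illustrative):

    import java.util.concurrent.{Executors, TimeUnit}

    object JobPoolSketch {
      def main(args: Array[String]) {
        // The default of 1 thread preserves in-order batch execution, as in JobScheduler
        val numConcurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
        val executor = Executors.newFixedThreadPool(numConcurrentJobs)

        // One Runnable per job, mirroring JobHandler above
        (1 to 4).foreach { i =>
          executor.execute(new Runnable {
            def run() { println("running job " + i + " on " + Thread.currentThread.getName) }
          })
        }

        // Graceful shutdown with a hard fallback, as in JobScheduler.stop()
        executor.shutdown()
        if (!executor.awaitTermination(5, TimeUnit.SECONDS)) executor.shutdownNow()
      }
    }
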
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
new file mode 100644
index 0000000000..57268674ea
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable.HashSet
+import org.apache.spark.streaming.Time
+
+/** Class representing a set of Jobs that all belong to the same batch. */
+private[streaming]
+case class JobSet(time: Time, jobs: Seq[Job]) {
+
+ private val incompleteJobs = new HashSet[Job]()
+ var submissionTime = System.currentTimeMillis() // when this jobset was submitted
+ var processingStartTime = -1L // when the first job of this jobset started processing
+ var processingEndTime = -1L // when the last job of this jobset finished processing
+
+ jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) }
+ incompleteJobs ++= jobs
+
+ def beforeJobStart(job: Job) {
+ if (processingStartTime < 0) processingStartTime = System.currentTimeMillis()
+ }
+
+ def afterJobStop(job: Job) {
+ incompleteJobs -= job
+ if (hasCompleted) processingEndTime = System.currentTimeMillis()
+ }
+
+ def hasStarted() = (processingStartTime > 0)
+
+ def hasCompleted() = incompleteJobs.isEmpty
+
+ // Time taken to process all the jobs from the time they started processing
+ // (i.e. not including the time they wait in the streaming scheduler queue)
+ def processingDelay = processingEndTime - processingStartTime
+
+ // Time taken to process all the jobs, measured from the batch time (time.milliseconds)
+ // (i.e. including the time they wait in the streaming scheduler queue)
+ def totalDelay = {
+ processingEndTime - time.milliseconds
+ }
+
+ def toBatchInfo(): BatchInfo = {
+ new BatchInfo(
+ time,
+ submissionTime,
+ if (processingStartTime >= 0) Some(processingStartTime) else None,
+ if (processingEndTime >= 0) Some(processingEndTime) else None
+ )
+ }
+}
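
The two JobSet metrics are easy to conflate: processingDelay covers only execution, while totalDelay is measured from the batch time and therefore also includes queueing. A worked example with hypothetical timestamps:

    object DelayArithmetic {
      def main(args: Array[String]) {
        // Hypothetical timestamps (ms) for one batch, chosen purely for illustration
        val batchTime           = 1000L  // time.milliseconds: when the batch was due
        val processingStartTime = 1300L  // first job started; it waited 300 ms in the queue
        val processingEndTime   = 1900L  // last job finished

        val processingDelay = processingEndTime - processingStartTime  // 600 ms of actual work
        val totalDelay      = processingEndTime - batchTime            // 900 ms end to end
        assert(totalDelay == processingDelay + 300L)  // the queue wait accounts for the difference
        println("processingDelay=" + processingDelay + " totalDelay=" + totalDelay)
      }
    }
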
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index 6e9a781978..abff55d77c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
@@ -31,6 +31,7 @@ import akka.actor._
import akka.pattern.ask
import akka.dispatch._
import org.apache.spark.storage.BlockId
+import org.apache.spark.streaming.{Time, StreamingContext}
private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
new file mode 100644
index 0000000000..36225e190c
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable.Queue
+import org.apache.spark.util.Distribution
+
+/** Base trait for events related to StreamingListener */
+sealed trait StreamingListenerEvent
+
+case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent
+
+case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent
+
+
+/**
+ * A listener interface for receiving information about an ongoing streaming
+ * computation.
+ */
+trait StreamingListener {
+ /**
+ * Called when processing of a batch has completed
+ */
+ def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { }
+
+ /**
+ * Called when processing of a batch has started
+ */
+ def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { }
+}
+
+
+/**
+ * A simple StreamingListener that logs summary statistics across Spark Streaming batches
+ * @param numBatchInfos Number of last batches to consider for generating statistics (default: 10)
+ */
+class StatsReportListener(numBatchInfos: Int = 10) extends StreamingListener {
+ // Queue containing latest completed batches
+ val batchInfos = new Queue[BatchInfo]()
+
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+ batchInfos.enqueue(batchCompleted.batchInfo)
+ if (batchInfos.size > numBatchInfos) batchInfos.dequeue()
+ printStats()
+ }
+
+ def printStats() {
+ showMillisDistribution("Total delay: ", _.totalDelay)
+ showMillisDistribution("Processing time: ", _.processingDelay)
+ }
+
+ def showMillisDistribution(heading: String, getMetric: BatchInfo => Option[Long]) {
+ org.apache.spark.scheduler.StatsReportListener.showMillisDistribution(
+ heading, extractDistribution(getMetric))
+ }
+
+ def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = {
+ Distribution(batchInfos.flatMap(getMetric(_)).map(_.toDouble))
+ }
+}
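
Using the listener interface above only requires overriding the callbacks of interest and registering the instance. A small sketch (DelayLogger is a hypothetical listener; the BatchInfo fields follow the test suite later in this diff):

    import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

    // Logs how long each completed batch took to process
    class DelayLogger extends StreamingListener {
      override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
        val info = batchCompleted.batchInfo
        // processingDelay is an Option: None until the batch has actually run
        println("Batch submitted at " + info.submissionTime +
          " processed in " + info.processingDelay.getOrElse(-1L) + " ms")
      }
    }

    // Registration (assuming ssc is a StreamingContext, as in StreamingListenerSuite below):
    // ssc.addStreamingListener(new DelayLogger)
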
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
new file mode 100644
index 0000000000..110a20f282
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.Logging
+import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import java.util.concurrent.LinkedBlockingQueue
+
+/** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */
+private[spark] class StreamingListenerBus() extends Logging {
+ private val listeners = new ArrayBuffer[StreamingListener]() with SynchronizedBuffer[StreamingListener]
+
+ /* Cap the capacity of the StreamingListenerEvent queue so we get an explicit error (rather than
+ * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
+ private val EVENT_QUEUE_CAPACITY = 10000
+ private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY)
+ private var queueFullErrorMessageLogged = false
+
+ new Thread("StreamingListenerBus") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ val event = eventQueue.take
+ event match {
+ case batchStarted: StreamingListenerBatchStarted =>
+ listeners.foreach(_.onBatchStarted(batchStarted))
+ case batchCompleted: StreamingListenerBatchCompleted =>
+ listeners.foreach(_.onBatchCompleted(batchCompleted))
+ case _ =>
+ }
+ }
+ }
+ }.start()
+
+ def addListener(listener: StreamingListener) {
+ listeners += listener
+ }
+
+ def post(event: StreamingListenerEvent) {
+ val eventAdded = eventQueue.offer(event)
+ if (!eventAdded && !queueFullErrorMessageLogged) {
+ logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
+ "rate at which tasks are being started by the scheduler.")
+ queueFullErrorMessageLogged = true
+ }
+ }
+
+ /**
+ * Waits until there are no more events in the queue, or until the specified time has elapsed.
+ * Used for testing only. Returns true if the queue has emptied and false if the specified time
+ * elapsed before the queue emptied.
+ */
+ def waitUntilEmpty(timeoutMillis: Int): Boolean = {
+ val finishTime = System.currentTimeMillis + timeoutMillis
+ while (!eventQueue.isEmpty()) {
+ if (System.currentTimeMillis > finishTime) {
+ return false
+ }
+ /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
+ * add overhead in the general case. */
+ Thread.sleep(10)
+ }
+ return true
+ }
+}
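
The bus above is a standard bounded producer/consumer: post() uses a non-blocking offer() so a full queue drops events instead of stalling the scheduler, and a daemon thread drains the queue. A self-contained sketch of the same pattern (EventBusSketch is illustrative):

    import java.util.concurrent.LinkedBlockingQueue

    object EventBusSketch {
      def main(args: Array[String]) {
        val queue = new LinkedBlockingQueue[String](10000)  // capped, like EVENT_QUEUE_CAPACITY

        val consumer = new Thread("consumer") {
          override def run() {
            while (true) {
              val event = queue.take()  // blocks until an event is available
              println("handling " + event)
            }
          }
        }
        consumer.setDaemon(true)  // do not keep the JVM alive just for the bus
        consumer.start()

        // Non-blocking publish: offer() returns false instead of stalling the producer
        (1 to 5).foreach(i => queue.offer("event-" + i))
        Thread.sleep(100)  // give the daemon thread a moment to drain before the JVM exits
      }
    }
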
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 60e986cb9d..ee6b433d1f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -26,17 +26,6 @@ import util.ManualClock
import org.apache.spark.{SparkContext, SparkConf}
class BasicOperationsSuite extends TestSuiteBase {
-
- override def framework = "BasicOperationsSuite"
-
- conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
-
- after {
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
- }
-
test("map") {
val input = Seq(1 to 4, 5 to 8, 9 to 12)
testOperation(
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index ca230fd056..c60a3f5390 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -40,29 +40,25 @@ import org.apache.spark.streaming.util.ManualClock
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
-class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
+class CheckpointSuite extends TestSuiteBase {
- before {
+ var ssc: StreamingContext = null
+
+ override def batchDuration = Milliseconds(500)
+
+ override def actuallyWait = true // to allow checkpoints to be written
+
+ override def beforeFunction() {
+ super.beforeFunction()
FileUtils.deleteDirectory(new File(checkpointDir))
}
- after {
+ override def afterFunction() {
+ super.afterFunction()
if (ssc != null) ssc.stop()
FileUtils.deleteDirectory(new File(checkpointDir))
-
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
}
- var ssc: StreamingContext = null
-
- override def framework = "CheckpointSuite"
-
- override def batchDuration = Milliseconds(500)
-
- override def actuallyWait = true
-
test("basic rdd checkpoints + dstream graph checkpoint recovery") {
assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 500 milliseconds")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
index 6337c5359c..da9b04de1a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
@@ -32,17 +32,22 @@ import collection.mutable.ArrayBuffer
* This test suite tests master failures at random times while the stream is running using
* the real clock.
*/
-class FailureSuite extends FunSuite with BeforeAndAfter with Logging {
+class FailureSuite extends TestSuiteBase with Logging {
var directory = "FailureSuite"
val numBatches = 30
- val batchDuration = Milliseconds(1000)
- before {
+ override def batchDuration = Milliseconds(1000)
+
+ override def useManualClock = false
+
+ override def beforeFunction() {
+ super.beforeFunction()
FileUtils.deleteDirectory(new File(directory))
}
- after {
+ override def afterFunction() {
+ super.afterFunction()
FileUtils.deleteDirectory(new File(directory))
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 8c16daa21c..52381c10b0 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -50,16 +50,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val testPort = 9999
- override def checkpointDir = "checkpoint"
-
- conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
-
- after {
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
- }
-
test("socket input stream") {
// Start the server
val testServer = new TestServer()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
new file mode 100644
index 0000000000..fa64142096
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import org.apache.spark.streaming.scheduler._
+import scala.collection.mutable.ArrayBuffer
+import org.scalatest.matchers.ShouldMatchers
+
+class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
+
+ val input = (1 to 4).map(Seq(_)).toSeq
+ val operation = (d: DStream[Int]) => d.map(x => x)
+
+ // To make sure that the processing start and end times in collected
+ // information are different for successive batches
+ override def batchDuration = Milliseconds(100)
+ override def actuallyWait = true
+
+ test("basic BatchInfo generation") {
+ val ssc = setupStreams(input, operation)
+ val collector = new BatchInfoCollector
+ ssc.addStreamingListener(collector)
+ runStreams(ssc, input.size, input.size)
+ val batchInfos = collector.batchInfos
+ batchInfos should have size 4
+
+ batchInfos.foreach(info => {
+ info.schedulingDelay should not be None
+ info.processingDelay should not be None
+ info.totalDelay should not be None
+ info.schedulingDelay.get should be >= 0L
+ info.processingDelay.get should be >= 0L
+ info.totalDelay.get should be >= 0L
+ })
+
+ isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true)
+ isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true)
+ isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true)
+ }
+
+ /** Check if a sequence of numbers is in increasing order */
+ def isInIncreasingOrder(seq: Seq[Long]): Boolean = {
+ for (i <- 1 until seq.size) {
+ if (seq(i - 1) > seq(i)) return false
+ }
+ true
+ }
+
+ /** Listener that collects information on processed batches */
+ class BatchInfoCollector extends StreamingListener {
+ val batchInfos = new ArrayBuffer[BatchInfo]
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+ batchInfos += batchCompleted.batchInfo
+ }
+ }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 3dd6718491..33464bc3a1 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -110,7 +110,7 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T],
trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Name of the framework for Spark context
- def framework = "TestSuiteBase"
+ def framework = this.getClass.getSimpleName
// Master for Spark context
def master = "local[2]"
@@ -127,15 +127,45 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Maximum time to wait before the test times out
def maxWaitTimeMillis = 10000
+ // Whether to use manual clock or not
+ def useManualClock = true
+
// Whether to actually wait in real time before changing manual clock
def actuallyWait = false
- // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things.
// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things.
val conf = new SparkConf()
.setMaster(master)
.setAppName(framework)
.set("spark.cleaner.ttl", "3600")
+ // Default before function for any streaming test suite. Override this
+ // to add suite-specific setup (i.e., do not register a separate before { } block).
+ def beforeFunction() {
+ if (useManualClock) {
+ conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+ }
+ }
+
+ // Default after function for any streaming test suite. Override this
+ // to add suite-specific teardown (i.e., do not register a separate after { } block).
+ def afterFunction() {
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+ }
+
+ before(beforeFunction)
+ after(afterFunction)
+
/**
* Set up required DStreams to test the DStream operation using the two sequences
* of input collections.
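
With the beforeFunction/afterFunction hooks above, individual suites override and call super rather than registering their own before/after blocks, as CheckpointSuite and FailureSuite do earlier in this diff. The minimal pattern, assuming the TestSuiteBase trait above (MyOperationsSuite is a hypothetical suite):

    class MyOperationsSuite extends TestSuiteBase {

      override def batchDuration = Milliseconds(500)

      override def beforeFunction() {
        super.beforeFunction()  // keeps the manual-clock configuration
        // suite-specific setup goes here
      }

      override def afterFunction() {
        super.afterFunction()   // keeps the spark.driver.port / spark.hostPort cleanup
        // suite-specific teardown goes here
      }
    }
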
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
index 3242c4cd11..c92c34d49b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
@@ -21,19 +21,9 @@ import org.apache.spark.streaming.StreamingContext._
class WindowOperationsSuite extends TestSuiteBase {
- conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+ override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer
- override def framework = "WindowOperationsSuite"
-
- override def maxWaitTimeMillis = 20000
-
- override def batchDuration = Seconds(1)
-
- after {
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
- }
+ override def batchDuration = Seconds(1) // making sure it's visible in this class
val largerSlideInput = Seq(
Seq(("a", 1)),
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index cc150888eb..595a7ee8c3 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -422,8 +422,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
}
def monitorApplication(appId: ApplicationId): Boolean = {
+ val interval = new SparkConf().getOrElse("spark.yarn.report.interval", "1000").toLong
+
while (true) {
- Thread.sleep(1000)
+ Thread.sleep(interval)
val report = super.getApplicationReport(appId)
logInfo("Application report from ASM: \n" +
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 4c9fee5695..5966a0f757 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -27,8 +27,8 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import org.apache.spark.Logging
-import org.apache.spark.scheduler.SplitInfo
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.util.Utils
import org.apache.hadoop.conf.Configuration
@@ -214,9 +214,9 @@ private[yarn] class YarnAllocationHandler(
// host if there are sufficiently large number of hosts/containers.
val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
- allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
- allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
- allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+ allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers)
+ allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers)
+ allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers)
// Run each of the allocated containers
for (container <- allocatedContainers) {
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
index 63a0449e5a..522e0a9ad7 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -20,13 +20,14 @@ package org.apache.spark.scheduler.cluster
import org.apache.spark._
import org.apache.hadoop.conf.Configuration
import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils
/**
*
* This scheduler launches workers through YARN by calling into Client to launch WorkerLauncher as the AM.
*/
-private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) {
def this(sc: SparkContext) = this(sc, new Configuration())
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 6feaaff014..4b69f5078b 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -20,9 +20,10 @@ package org.apache.spark.scheduler.cluster
import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
import org.apache.spark.{SparkException, Logging, SparkContext}
import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+import org.apache.spark.scheduler.TaskSchedulerImpl
private[spark] class YarnClientSchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with Logging {
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 29b3f22e13..2d9fbcb400 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -17,16 +17,20 @@
package org.apache.spark.scheduler.cluster
+import org.apache.hadoop.conf.Configuration
+
import org.apache.spark._
import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils
-import org.apache.hadoop.conf.Configuration
/**
*
- * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
+ * This is a simple extension to TaskSchedulerImpl (formerly ClusterScheduler) - to ensure that
+ * appropriate initialization of ApplicationMaster, etc. is done
*/
-private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
+ extends TaskSchedulerImpl(sc) {
logInfo("Created YarnClusterScheduler")